mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 06:30:03 +00:00
Compare commits
4 Commits
feat-bring
...
aesgi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
107eb56b1d | ||
|
|
5c07f5f340 | ||
|
|
7ad78d4219 | ||
|
|
a5153d5212 |
16
AGENTS.md
16
AGENTS.md
@@ -37,6 +37,22 @@ Run the Flask API (if needed):
|
||||
flask --app application/app.py run --host=0.0.0.0 --port=7091
|
||||
```
|
||||
|
||||
That's the fast inner-loop option — quick startup, the Werkzeug interactive
|
||||
debugger still works, and it hot-reloads on source changes. It serves the
|
||||
Flask routes only (`/api/*`, `/stream`, etc.).
|
||||
|
||||
If you need to exercise the full ASGI stack — the `/mcp` FastMCP endpoint,
|
||||
or to match the production runtime exactly — run the ASGI composition under
|
||||
uvicorn instead:
|
||||
|
||||
```bash
|
||||
uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload
|
||||
```
|
||||
|
||||
Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same
|
||||
`application.asgi:asgi_app` target; see `application/Dockerfile` for the
|
||||
full flag set.
|
||||
|
||||
Run the Celery worker in a separate terminal (if needed):
|
||||
|
||||
```bash
|
||||
|
||||
@@ -88,5 +88,15 @@ EXPOSE 7091
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
# Start Gunicorn
|
||||
CMD ["gunicorn", "-w", "1", "--timeout", "120", "--bind", "0.0.0.0:7091", "--preload", "application.wsgi:app"]
|
||||
CMD ["gunicorn", \
|
||||
"-w", "1", \
|
||||
"-k", "uvicorn_worker.UvicornWorker", \
|
||||
"--bind", "0.0.0.0:7091", \
|
||||
"--timeout", "180", \
|
||||
"--graceful-timeout", "120", \
|
||||
"--keep-alive", "5", \
|
||||
"--worker-tmp-dir", "/dev/shm", \
|
||||
"--max-requests", "1000", \
|
||||
"--max-requests-jitter", "100", \
|
||||
"--config", "application/gunicorn_conf.py", \
|
||||
"application.asgi:asgi_app"]
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_restx import fields, Resource
|
||||
|
||||
from application.api.answer.routes.base import answer_ns
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
from application.services.search_service import (
|
||||
InvalidAPIKey,
|
||||
SearchFailed,
|
||||
search,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@answer_ns.route("/api/search")
|
||||
class SearchResource(Resource):
|
||||
"""Fast search endpoint for retrieving relevant documents"""
|
||||
"""Fast search endpoint for retrieving relevant documents."""
|
||||
|
||||
search_model = answer_ns.model(
|
||||
"SearchModel",
|
||||
@@ -32,102 +32,10 @@ class SearchResource(Resource):
|
||||
},
|
||||
)
|
||||
|
||||
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
|
||||
"""Get source IDs connected to the API key/agent."""
|
||||
with db_readonly() as conn:
|
||||
agent_data = AgentsRepository(conn).find_by_key(api_key)
|
||||
if not agent_data:
|
||||
return []
|
||||
|
||||
source_ids: List[str] = []
|
||||
# extra_source_ids is a PG ARRAY(UUID) of source UUIDs.
|
||||
extra = agent_data.get("extra_source_ids") or []
|
||||
for src in extra:
|
||||
if src:
|
||||
source_ids.append(str(src))
|
||||
|
||||
if not source_ids:
|
||||
single = agent_data.get("source_id")
|
||||
if single:
|
||||
source_ids.append(str(single))
|
||||
|
||||
return source_ids
|
||||
|
||||
def _search_vectorstores(
|
||||
self, query: str, source_ids: List[str], chunks: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search across vectorstores and return results"""
|
||||
if not source_ids:
|
||||
return []
|
||||
|
||||
results = []
|
||||
chunks_per_source = max(1, chunks // len(source_ids))
|
||||
seen_texts = set()
|
||||
|
||||
for source_id in source_ids:
|
||||
if not source_id or not source_id.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
docsearch = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
|
||||
)
|
||||
docs = docsearch.search(query, k=chunks_per_source * 2)
|
||||
|
||||
for doc in docs:
|
||||
if len(results) >= chunks:
|
||||
break
|
||||
|
||||
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
|
||||
page_content = doc.page_content
|
||||
metadata = doc.metadata
|
||||
else:
|
||||
page_content = doc.get("text", doc.get("page_content", ""))
|
||||
metadata = doc.get("metadata", {})
|
||||
|
||||
# Skip duplicates
|
||||
text_hash = hash(page_content[:200])
|
||||
if text_hash in seen_texts:
|
||||
continue
|
||||
seen_texts.add(text_hash)
|
||||
|
||||
title = metadata.get(
|
||||
"title", metadata.get("post_title", "")
|
||||
)
|
||||
if not isinstance(title, str):
|
||||
title = str(title) if title else ""
|
||||
|
||||
# Clean up title
|
||||
if title:
|
||||
title = title.split("/")[-1]
|
||||
else:
|
||||
# Use filename or first part of content as title
|
||||
title = metadata.get("filename", page_content[:50] + "...")
|
||||
|
||||
source = metadata.get("source", source_id)
|
||||
|
||||
results.append({
|
||||
"text": page_content,
|
||||
"title": title,
|
||||
"source": source,
|
||||
})
|
||||
|
||||
if len(results) >= chunks:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error searching vectorstore {source_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
|
||||
return results[:chunks]
|
||||
|
||||
@answer_ns.expect(search_model)
|
||||
@answer_ns.doc(description="Search for relevant documents based on query")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
data = request.get_json() or {}
|
||||
|
||||
question = data.get("question")
|
||||
api_key = data.get("api_key")
|
||||
@@ -135,32 +43,13 @@ class SearchResource(Resource):
|
||||
|
||||
if not question:
|
||||
return make_response({"error": "question is required"}, 400)
|
||||
|
||||
if not api_key:
|
||||
return make_response({"error": "api_key is required"}, 400)
|
||||
|
||||
# Validate API key
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
if not agent:
|
||||
return make_response({"error": "Invalid API key"}, 401)
|
||||
|
||||
try:
|
||||
# Get sources connected to this API key
|
||||
source_ids = self._get_sources_from_api_key(api_key)
|
||||
|
||||
if not source_ids:
|
||||
return make_response([], 200)
|
||||
|
||||
# Perform search
|
||||
results = self._search_vectorstores(question, source_ids, chunks)
|
||||
|
||||
return make_response(results, 200)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/search - error: {str(e)}",
|
||||
extra={"error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response(search(api_key, question, chunks), 200)
|
||||
except InvalidAPIKey:
|
||||
return make_response({"error": "Invalid API key"}, 401)
|
||||
except SearchFailed:
|
||||
logger.exception("/api/search failed")
|
||||
return make_response({"error": "Search failed"}, 500)
|
||||
|
||||
@@ -4,7 +4,7 @@ import platform
|
||||
import uuid
|
||||
|
||||
import dotenv
|
||||
from flask import Flask, jsonify, redirect, request
|
||||
from flask import Flask, Response, jsonify, redirect, request
|
||||
from jose import jwt
|
||||
|
||||
from application.auth import handle_auth
|
||||
@@ -149,12 +149,11 @@ def authenticate_request():
|
||||
|
||||
|
||||
@app.after_request
|
||||
def after_request(response):
|
||||
response.headers.add("Access-Control-Allow-Origin", "*")
|
||||
response.headers.add("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
response.headers.add(
|
||||
"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
|
||||
)
|
||||
def after_request(response: Response) -> Response:
|
||||
"""Add CORS headers for the pure Flask development entrypoint."""
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
|
||||
return response
|
||||
|
||||
|
||||
|
||||
33
application/asgi.py
Normal file
33
application/asgi.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""ASGI entrypoint: Flask (WSGI) + FastMCP on the same process."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from a2wsgi import WSGIMiddleware
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.routing import Mount
|
||||
|
||||
from application.app import app as flask_app
|
||||
from application.mcp_server import mcp
|
||||
|
||||
_WSGI_THREADPOOL = 32
|
||||
|
||||
mcp_app = mcp.http_app(path="/")
|
||||
|
||||
asgi_app = Starlette(
|
||||
routes=[
|
||||
Mount("/mcp", app=mcp_app),
|
||||
Mount("/", app=WSGIMiddleware(flask_app, workers=_WSGI_THREADPOOL)),
|
||||
],
|
||||
middleware=[
|
||||
Middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization", "Mcp-Session-Id"],
|
||||
expose_headers=["Mcp-Session-Id"],
|
||||
),
|
||||
],
|
||||
lifespan=mcp_app.lifespan,
|
||||
)
|
||||
72
application/gunicorn_conf.py
Normal file
72
application/gunicorn_conf.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Gunicorn config — keeps uvicorn's access log in NCSA format."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
|
||||
# NCSA common log format:
|
||||
# %(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"
|
||||
# Uvicorn's access formatter exposes a ``client_addr``/``request_line``/
|
||||
# ``status_code`` trio but not the full NCSA field set, so we re-derive
|
||||
# what we can.
|
||||
_NCSA_FMT = (
|
||||
'%(client_addr)s - - [%(asctime)s] "%(request_line)s" %(status_code)s'
|
||||
)
|
||||
|
||||
logconfig_dict = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"ncsa_access": {
|
||||
"()": "uvicorn.logging.AccessFormatter",
|
||||
"fmt": _NCSA_FMT,
|
||||
"datefmt": "%d/%b/%Y:%H:%M:%S %z",
|
||||
"use_colors": False,
|
||||
},
|
||||
"default": {
|
||||
"format": "[%(asctime)s] [%(process)d] [%(levelname)s] %(name)s: %(message)s",
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"access": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "ncsa_access",
|
||||
"stream": "ext://sys.stdout",
|
||||
},
|
||||
"default": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "default",
|
||||
"stream": "ext://sys.stderr",
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"uvicorn": {"handlers": ["default"], "level": "INFO", "propagate": False},
|
||||
"uvicorn.error": {
|
||||
"handlers": ["default"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
"uvicorn.access": {
|
||||
"handlers": ["access"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
"gunicorn.error": {
|
||||
"handlers": ["default"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
"gunicorn.access": {
|
||||
"handlers": ["access"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
},
|
||||
"root": {"handlers": ["default"], "level": "INFO"},
|
||||
}
|
||||
|
||||
|
||||
def on_starting(server): # pragma: no cover — gunicorn hook
|
||||
"""Ensure gunicorn's own loggers use the configured handlers."""
|
||||
logging.config.dictConfig(logconfig_dict)
|
||||
59
application/mcp_server.py
Normal file
59
application/mcp_server.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""FastMCP server exposing DocsGPT retrieval over streamable HTTP.
|
||||
|
||||
Mounted at ``/mcp`` by ``application/asgi.py``. Bearer tokens are the
|
||||
existing DocsGPT agent API keys — no new credential surface.
|
||||
|
||||
The tool reads the ``Authorization`` header directly via
|
||||
``get_http_headers(include={"authorization"})``. The ``include`` kwarg
|
||||
is required: by default ``get_http_headers`` strips ``authorization``
|
||||
(and a handful of other hop-by-hop headers) so they aren't forwarded
|
||||
to downstream services — since we deliberately want the caller's
|
||||
token, we opt it back in.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
from application.services.search_service import (
|
||||
InvalidAPIKey,
|
||||
SearchFailed,
|
||||
search,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
mcp = FastMCP("docsgpt")
|
||||
|
||||
|
||||
def _extract_bearer_token() -> str | None:
|
||||
auth = get_http_headers(include={"authorization"}).get("authorization", "")
|
||||
parts = auth.split(None, 1)
|
||||
if len(parts) != 2 or parts[0].lower() != "bearer" or not parts[1]:
|
||||
return None
|
||||
return parts[1]
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def search_docs(query: str, chunks: int = 5) -> list[dict]:
|
||||
"""Search the caller's DocsGPT knowledge base.
|
||||
|
||||
Authentication is via ``Authorization: Bearer <agent-api-key>`` on
|
||||
the MCP request — the same opaque key that ``/api/search`` accepts
|
||||
in its JSON body. Returns at most ``chunks`` hits, each a dict with
|
||||
``text``, ``title``, ``source`` keys.
|
||||
"""
|
||||
api_key = _extract_bearer_token()
|
||||
if not api_key:
|
||||
raise PermissionError("Missing Bearer token")
|
||||
try:
|
||||
return await asyncio.to_thread(search, api_key, query, chunks)
|
||||
except InvalidAPIKey as exc:
|
||||
raise PermissionError("Invalid API key") from exc
|
||||
except SearchFailed:
|
||||
logger.exception("search_docs failed")
|
||||
raise
|
||||
@@ -1,5 +1,7 @@
|
||||
a2wsgi==1.10.10
|
||||
alembic>=1.13,<2
|
||||
anthropic==0.88.0
|
||||
asgiref>=3.11.1
|
||||
boto3==1.42.83
|
||||
beautifulsoup4==4.14.3
|
||||
cel-python==0.5.0
|
||||
@@ -14,7 +16,7 @@ docx2txt==0.9
|
||||
ddgs>=8.0.0
|
||||
fast-ebook
|
||||
elevenlabs==2.43.0
|
||||
Flask==3.1.3
|
||||
Flask==3.1.1
|
||||
faiss-cpu==1.13.2
|
||||
fastmcp==3.2.4
|
||||
flask-restx==1.3.2
|
||||
@@ -76,6 +78,7 @@ requests==2.33.1
|
||||
retry==0.9.2
|
||||
sentence-transformers==5.3.0
|
||||
sqlalchemy>=2.0,<3
|
||||
starlette>=1.0,<2
|
||||
tiktoken==0.12.0
|
||||
tokenizers==0.22.2
|
||||
torch==2.11.0
|
||||
@@ -85,6 +88,8 @@ typing-extensions==4.15.0
|
||||
typing-inspect==0.9.0
|
||||
tzdata==2026.1
|
||||
urllib3==2.6.3
|
||||
uvicorn[standard]>=0.30,<1
|
||||
uvicorn-worker>=0.4,<1
|
||||
vine==5.1.0
|
||||
wcwidth==0.6.0
|
||||
werkzeug>=3.1.0
|
||||
|
||||
0
application/services/__init__.py
Normal file
0
application/services/__init__.py
Normal file
153
application/services/search_service.py
Normal file
153
application/services/search_service.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Shared retrieval service used by the HTTP search route and the MCP tool.
|
||||
|
||||
Flask-free. Raises domain exceptions (``InvalidAPIKey``, ``SearchFailed``)
|
||||
that callers translate into their own wire protocol (HTTP status codes,
|
||||
MCP error responses, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InvalidAPIKey(Exception):
|
||||
"""The supplied ``api_key`` does not resolve to an agent."""
|
||||
|
||||
|
||||
class SearchFailed(Exception):
|
||||
"""Unexpected error during retrieval (e.g. DB outage). Caller maps to 5xx."""
|
||||
|
||||
|
||||
def _collect_source_ids(agent: Dict[str, Any]) -> List[str]:
|
||||
"""Extract the ordered list of source UUIDs to search.
|
||||
|
||||
Prefers ``extra_source_ids`` (PG ARRAY(UUID) of multi-source agents);
|
||||
falls back to the legacy single ``source_id`` field.
|
||||
"""
|
||||
source_ids: List[str] = []
|
||||
extra = agent.get("extra_source_ids") or []
|
||||
for src in extra:
|
||||
if src:
|
||||
source_ids.append(str(src))
|
||||
if not source_ids:
|
||||
single = agent.get("source_id")
|
||||
if single:
|
||||
source_ids.append(str(single))
|
||||
return source_ids
|
||||
|
||||
|
||||
def _search_sources(
|
||||
query: str, source_ids: List[str], chunks: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search across each source's vectorstore and return up to ``chunks`` hits.
|
||||
|
||||
Per-source errors are logged and skipped so one broken index doesn't
|
||||
take down the whole search. Results are de-duplicated by content hash.
|
||||
"""
|
||||
if chunks <= 0 or not source_ids:
|
||||
return []
|
||||
|
||||
results: List[Dict[str, Any]] = []
|
||||
chunks_per_source = max(1, chunks // len(source_ids))
|
||||
seen_texts: set[int] = set()
|
||||
|
||||
for source_id in source_ids:
|
||||
if not source_id or not source_id.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
docsearch = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
|
||||
)
|
||||
docs = docsearch.search(query, k=chunks_per_source * 2)
|
||||
|
||||
for doc in docs:
|
||||
if len(results) >= chunks:
|
||||
break
|
||||
|
||||
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
|
||||
page_content = doc.page_content
|
||||
metadata = doc.metadata
|
||||
else:
|
||||
page_content = doc.get("text", doc.get("page_content", ""))
|
||||
metadata = doc.get("metadata", {})
|
||||
|
||||
text_hash = hash(page_content[:200])
|
||||
if text_hash in seen_texts:
|
||||
continue
|
||||
seen_texts.add(text_hash)
|
||||
|
||||
title = metadata.get("title", metadata.get("post_title", ""))
|
||||
if not isinstance(title, str):
|
||||
title = str(title) if title else ""
|
||||
|
||||
if title:
|
||||
title = title.split("/")[-1]
|
||||
else:
|
||||
title = metadata.get("filename", page_content[:50] + "...")
|
||||
|
||||
source = metadata.get("source", source_id)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"text": page_content,
|
||||
"title": title,
|
||||
"source": source,
|
||||
}
|
||||
)
|
||||
|
||||
if len(results) >= chunks:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error searching vectorstore {source_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
|
||||
return results[:chunks]
|
||||
|
||||
|
||||
def search(api_key: str, query: str, chunks: int = 5) -> List[Dict[str, Any]]:
|
||||
"""Resolve an agent by API key and search its sources.
|
||||
|
||||
Args:
|
||||
api_key: Agent API key (the opaque string stored on
|
||||
``agents.key`` in Postgres).
|
||||
query: Free-text search query.
|
||||
chunks: Max number of hits to return.
|
||||
|
||||
Returns:
|
||||
List of hit dicts with ``text``, ``title``, ``source`` keys.
|
||||
Empty list if the agent has no sources configured.
|
||||
|
||||
Raises:
|
||||
InvalidAPIKey: if ``api_key`` does not resolve to an agent.
|
||||
SearchFailed: on unexpected DB / infrastructure errors.
|
||||
"""
|
||||
if chunks <= 0:
|
||||
return []
|
||||
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
except Exception as e:
|
||||
raise SearchFailed("agent lookup failed") from e
|
||||
|
||||
if not agent:
|
||||
raise InvalidAPIKey()
|
||||
|
||||
source_ids = _collect_source_ids(agent)
|
||||
if not source_ids:
|
||||
return []
|
||||
|
||||
return _search_sources(query, source_ids, chunks)
|
||||
@@ -104,7 +104,15 @@ To run the DocsGPT backend locally, you'll need to set up a Python environment a
|
||||
flask --app application/app.py run --host=0.0.0.0 --port=7091
|
||||
```
|
||||
|
||||
This command will launch the backend server, making it accessible on `http://localhost:7091`.
|
||||
This command will launch the backend server, making it accessible on `http://localhost:7091`. It's the fastest inner-loop option for day-to-day development — the Werkzeug interactive debugger still works and it hot-reloads on source changes. It serves the Flask routes only.
|
||||
|
||||
If you need to exercise the full ASGI stack — the `/mcp` endpoint (FastMCP server), or to match the production runtime — run the ASGI composition under uvicorn instead:
|
||||
|
||||
```bash
|
||||
uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload
|
||||
```
|
||||
|
||||
Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same `application.asgi:asgi_app` target.
|
||||
|
||||
6. **Start the Celery Worker:**
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ markers =
|
||||
unit: Unit tests
|
||||
integration: Integration tests
|
||||
slow: Slow running tests
|
||||
asyncio_mode = strict
|
||||
filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::PendingDeprecationWarning
|
||||
|
||||
137
scripts/mock_llm.py
Normal file
137
scripts/mock_llm.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Mock OpenAI-compatible LLM server for benchmarking.
|
||||
|
||||
Fixed 5-second generation (100 tokens × 50 ms/token). No auth. Emits SSE
|
||||
chunks in OpenAI's chat.completions streaming format, or a single response
|
||||
when stream=false. Run on 127.0.0.1:8090 — point DocsGPT at it via
|
||||
OPENAI_BASE_URL=http://127.0.0.1:8090/v1.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
TOKEN_COUNT = 100
|
||||
TOKEN_DELAY_S = 0.05 # 100 * 0.05 = 5.0 s
|
||||
|
||||
logger = logging.getLogger("mock_llm")
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s mock: %(message)s")
|
||||
|
||||
FILLER_TOKENS = [
|
||||
"Lorem", " ipsum", " dolor", " sit", " amet", ",", " consectetur",
|
||||
" adipiscing", " elit", ".", " Sed", " do", " eiusmod", " tempor",
|
||||
" incididunt", " ut", " labore", " et", " dolore", " magna", " aliqua",
|
||||
".", " Ut", " enim", " ad", " minim", " veniam", ",", " quis", " nostrud",
|
||||
" exercitation", " ullamco", " laboris", " nisi", " ut", " aliquip",
|
||||
" ex", " ea", " commodo", " consequat", ".", " Duis", " aute", " irure",
|
||||
" dolor", " in", " reprehenderit", " in", " voluptate", " velit",
|
||||
" esse", " cillum", " dolore", " eu", " fugiat", " nulla", " pariatur",
|
||||
".", " Excepteur", " sint", " occaecat", " cupidatat", " non", " proident",
|
||||
",", " sunt", " in", " culpa", " qui", " officia", " deserunt",
|
||||
" mollit", " anim", " id", " est", " laborum", ".", " Curabitur",
|
||||
" pretium", " tincidunt", " lacus", ".", " Nulla", " gravida", " orci",
|
||||
" a", " odio", ".", " Nullam", " varius", ",", " turpis", " et",
|
||||
" commodo", " pharetra", ",", " est", " eros", " bibendum", " elit",
|
||||
".",
|
||||
]
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def _token_stream_id() -> str:
|
||||
return f"chatcmpl-mock-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
|
||||
def _sse_chunk(completion_id: str, model: str, delta: dict, finish_reason=None) -> str:
|
||||
payload = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
}
|
||||
return f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
|
||||
async def _stream_response(model: str, req_id: str):
|
||||
completion_id = _token_stream_id()
|
||||
yield _sse_chunk(completion_id, model, {"role": "assistant", "content": ""})
|
||||
for i, tok in enumerate(FILLER_TOKENS[:TOKEN_COUNT]):
|
||||
await asyncio.sleep(TOKEN_DELAY_S)
|
||||
yield _sse_chunk(completion_id, model, {"content": tok})
|
||||
yield _sse_chunk(completion_id, model, {}, finish_reason="stop")
|
||||
yield "data: [DONE]\n\n"
|
||||
logger.info("[%s] stream done", req_id)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request):
|
||||
body = await request.json()
|
||||
model = body.get("model", "mock")
|
||||
stream = bool(body.get("stream", False))
|
||||
req_id = uuid.uuid4().hex[:8]
|
||||
logger.info("[%s] /chat/completions stream=%s model=%s max_tokens=%s", req_id, stream, model, body.get("max_tokens"))
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(
|
||||
_stream_response(model, req_id),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
await asyncio.sleep(TOKEN_COUNT * TOKEN_DELAY_S)
|
||||
logger.info("[%s] non-stream done", req_id)
|
||||
text = "".join(FILLER_TOKENS[:TOKEN_COUNT])
|
||||
completion_id = _token_stream_id()
|
||||
return JSONResponse(
|
||||
{
|
||||
"id": completion_id,
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": text},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": TOKEN_COUNT,
|
||||
"total_tokens": 10 + TOKEN_COUNT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def list_models():
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [{"id": "mock", "object": "model", "owned_by": "mock"}],
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="127.0.0.1", port=8090, log_level="info")
|
||||
@@ -1,3 +1,17 @@
|
||||
"""Tests for /api/search route (application/api/answer/routes/search.py).
|
||||
|
||||
Retrieval logic lives in ``application/services/search_service.py`` and
|
||||
has its own unit tests in ``tests/services/test_search_service.py``. The
|
||||
tests below focus on what the route specifically owns:
|
||||
|
||||
* Request validation (400 for missing fields).
|
||||
* Translation of the service's ``InvalidAPIKey`` / ``SearchFailed``
|
||||
exceptions to HTTP status codes (401 / 500).
|
||||
* End-to-end happy path against a real ephemeral Postgres via
|
||||
``pg_conn``, to catch regressions in the route's wiring to the
|
||||
service and repositories.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -6,254 +20,97 @@ import pytest
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSearchResourceValidation:
|
||||
pass
|
||||
|
||||
def test_returns_error_when_question_missing(self, mock_mongo_db, flask_app):
|
||||
def test_returns_400_when_question_missing(self, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
|
||||
with flask_app.app_context():
|
||||
with flask_app.test_request_context(
|
||||
json={"api_key": "test_key"}
|
||||
):
|
||||
resource = SearchResource()
|
||||
result = resource.post()
|
||||
|
||||
with flask_app.test_request_context(json={"api_key": "test_key"}):
|
||||
result = SearchResource().post()
|
||||
assert result.status_code == 400
|
||||
assert "question" in result.json["error"]
|
||||
|
||||
def test_returns_error_when_api_key_missing(self, mock_mongo_db, flask_app):
|
||||
def test_returns_400_when_api_key_missing(self, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
|
||||
with flask_app.app_context():
|
||||
with flask_app.test_request_context(
|
||||
json={"question": "test query"}
|
||||
):
|
||||
resource = SearchResource()
|
||||
result = resource.post()
|
||||
|
||||
with flask_app.test_request_context(json={"question": "test query"}):
|
||||
result = SearchResource().post()
|
||||
assert result.status_code == 400
|
||||
assert "api_key" in result.json["error"]
|
||||
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetSourcesFromApiKey:
|
||||
pass
|
||||
class TestSearchResourceExceptionMapping:
|
||||
"""Verify the route maps service exceptions to HTTP status codes.
|
||||
|
||||
def test_returns_source_id_via_patched_method(self, mock_mongo_db, flask_app):
|
||||
"""Test that _get_sources_from_api_key can return multiple sources via patch."""
|
||||
The service function itself is patched; these tests do not care about
|
||||
the search logic — only that 401/500/200 are produced correctly from
|
||||
the three possible service outcomes.
|
||||
"""
|
||||
|
||||
def test_invalid_api_key_returns_401(self, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
from application.services.search_service import InvalidAPIKey
|
||||
|
||||
with flask_app.app_context(), flask_app.test_request_context(
|
||||
json={"question": "q", "api_key": "bad"}
|
||||
), patch(
|
||||
"application.api.answer.routes.search.search",
|
||||
side_effect=InvalidAPIKey(),
|
||||
):
|
||||
result = SearchResource().post()
|
||||
assert result.status_code == 401
|
||||
assert result.json == {"error": "Invalid API key"}
|
||||
|
||||
def test_search_failed_returns_500(self, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
from application.services.search_service import SearchFailed
|
||||
|
||||
with flask_app.app_context(), flask_app.test_request_context(
|
||||
json={"question": "q", "api_key": "k"}
|
||||
), patch(
|
||||
"application.api.answer.routes.search.search",
|
||||
side_effect=SearchFailed("boom"),
|
||||
):
|
||||
result = SearchResource().post()
|
||||
assert result.status_code == 500
|
||||
assert result.json == {"error": "Search failed"}
|
||||
|
||||
def test_happy_path_passes_service_result_through(self, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = SearchResource()
|
||||
hits = [{"text": "t", "title": "T", "source": "s"}]
|
||||
with flask_app.app_context(), flask_app.test_request_context(
|
||||
json={"question": "q", "api_key": "k", "chunks": 7}
|
||||
), patch(
|
||||
"application.api.answer.routes.search.search",
|
||||
return_value=hits,
|
||||
) as mock_search:
|
||||
result = SearchResource().post()
|
||||
assert result.status_code == 200
|
||||
assert result.json == hits
|
||||
mock_search.assert_called_once_with("k", "q", 7)
|
||||
|
||||
with patch.object(resource, "_get_sources_from_api_key", return_value=["src1", "src2"]):
|
||||
result = resource._get_sources_from_api_key("any_key")
|
||||
|
||||
assert len(result) == 2
|
||||
assert "src1" in result
|
||||
assert "src2" in result
|
||||
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSearchVectorstores:
|
||||
pass
|
||||
|
||||
def test_returns_empty_when_no_source_ids(self, mock_mongo_db, flask_app):
|
||||
def test_default_chunks_is_5(self, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = SearchResource()
|
||||
|
||||
result = resource._search_vectorstores("test query", [], 5)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_skips_empty_source_ids(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = SearchResource()
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
|
||||
) as mock_create:
|
||||
mock_vectorstore = MagicMock()
|
||||
mock_vectorstore.search.return_value = []
|
||||
mock_create.return_value = mock_vectorstore
|
||||
|
||||
result = resource._search_vectorstores("test query", ["", " "], 5)
|
||||
|
||||
mock_create.assert_not_called()
|
||||
assert result == []
|
||||
|
||||
def test_returns_search_results(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = SearchResource()
|
||||
|
||||
mock_doc = {
|
||||
"text": "Test content",
|
||||
"page_content": "Test content",
|
||||
"metadata": {
|
||||
"title": "Test Title",
|
||||
"source": "/path/to/doc",
|
||||
},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
|
||||
) as mock_create:
|
||||
mock_vectorstore = MagicMock()
|
||||
mock_vectorstore.search.return_value = [mock_doc]
|
||||
mock_create.return_value = mock_vectorstore
|
||||
|
||||
result = resource._search_vectorstores("test query", ["source_id"], 5)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["text"] == "Test content"
|
||||
assert result[0]["title"] == "Test Title"
|
||||
assert result[0]["source"] == "/path/to/doc"
|
||||
|
||||
def test_handles_langchain_document_format(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = SearchResource()
|
||||
|
||||
mock_doc = MagicMock()
|
||||
mock_doc.page_content = "Langchain content"
|
||||
mock_doc.metadata = {"title": "LC Title", "source": "/lc/path"}
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
|
||||
) as mock_create:
|
||||
mock_vectorstore = MagicMock()
|
||||
mock_vectorstore.search.return_value = [mock_doc]
|
||||
mock_create.return_value = mock_vectorstore
|
||||
|
||||
result = resource._search_vectorstores("test query", ["source_id"], 5)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["text"] == "Langchain content"
|
||||
assert result[0]["title"] == "LC Title"
|
||||
|
||||
def test_respects_chunks_limit(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = SearchResource()
|
||||
|
||||
mock_docs = [
|
||||
{"text": f"Content {i}", "metadata": {"title": f"Title {i}"}}
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
|
||||
) as mock_create:
|
||||
mock_vectorstore = MagicMock()
|
||||
mock_vectorstore.search.return_value = mock_docs
|
||||
mock_create.return_value = mock_vectorstore
|
||||
|
||||
result = resource._search_vectorstores("test query", ["source_id"], 3)
|
||||
|
||||
assert len(result) == 3
|
||||
|
||||
def test_deduplicates_results(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = SearchResource()
|
||||
|
||||
duplicate_text = "Duplicate content " * 20
|
||||
mock_docs = [
|
||||
{"text": duplicate_text, "metadata": {"title": "Title 1"}},
|
||||
{"text": duplicate_text, "metadata": {"title": "Title 2"}},
|
||||
{"text": "Unique content", "metadata": {"title": "Title 3"}},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
|
||||
) as mock_create:
|
||||
mock_vectorstore = MagicMock()
|
||||
mock_vectorstore.search.return_value = mock_docs
|
||||
mock_create.return_value = mock_vectorstore
|
||||
|
||||
result = resource._search_vectorstores("test query", ["source_id"], 5)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
def test_handles_vectorstore_error_gracefully(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = SearchResource()
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
|
||||
) as mock_create:
|
||||
mock_create.side_effect = Exception("Vectorstore error")
|
||||
|
||||
result = resource._search_vectorstores("test query", ["source_id"], 5)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_uses_filename_as_title_fallback(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = SearchResource()
|
||||
|
||||
mock_doc = {
|
||||
"text": "Content without title",
|
||||
"metadata": {"filename": "document.pdf"},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
|
||||
) as mock_create:
|
||||
mock_vectorstore = MagicMock()
|
||||
mock_vectorstore.search.return_value = [mock_doc]
|
||||
mock_create.return_value = mock_vectorstore
|
||||
|
||||
result = resource._search_vectorstores("test query", ["source_id"], 5)
|
||||
|
||||
assert result[0]["title"] == "document.pdf"
|
||||
|
||||
def test_uses_content_snippet_as_title_last_resort(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = SearchResource()
|
||||
|
||||
mock_doc = {
|
||||
"text": "Content without any title metadata at all",
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
|
||||
) as mock_create:
|
||||
mock_vectorstore = MagicMock()
|
||||
mock_vectorstore.search.return_value = [mock_doc]
|
||||
mock_create.return_value = mock_vectorstore
|
||||
|
||||
result = resource._search_vectorstores("test query", ["source_id"], 5)
|
||||
|
||||
assert "Content without any title" in result[0]["title"]
|
||||
assert result[0]["title"].endswith("...")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSearchEndpoint:
|
||||
pass
|
||||
with flask_app.app_context(), flask_app.test_request_context(
|
||||
json={"question": "q", "api_key": "k"} # no chunks field
|
||||
), patch(
|
||||
"application.api.answer.routes.search.search",
|
||||
return_value=[],
|
||||
) as mock_search:
|
||||
SearchResource().post()
|
||||
mock_search.assert_called_once_with("k", "q", 5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Real-PG tests for SearchResource.
|
||||
# End-to-end against a real ephemeral Postgres.
|
||||
#
|
||||
# These exercise the full route → service → repository → DB path, patching
|
||||
# only ``VectorCreator.create_vectorstore`` (so we don't need real embeddings
|
||||
# or a vector index). ``db_readonly`` is redirected at the *service* module
|
||||
# since that's where the import now lives.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -264,7 +121,7 @@ def _patch_search_db(conn):
|
||||
yield conn
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.search.db_readonly", _yield
|
||||
"application.services.search_service.db_readonly", _yield
|
||||
):
|
||||
yield
|
||||
|
||||
@@ -298,9 +155,7 @@ class TestSearchResourcePgConn:
|
||||
def test_search_returns_results(self, pg_conn, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.sources import (
|
||||
SourcesRepository,
|
||||
)
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
|
||||
src = SourcesRepository(pg_conn).create("src", user_id="u")
|
||||
AgentsRepository(pg_conn).create(
|
||||
@@ -315,7 +170,7 @@ class TestSearchResourcePgConn:
|
||||
]
|
||||
|
||||
with _patch_search_db(pg_conn), patch(
|
||||
"application.api.answer.routes.search.VectorCreator.create_vectorstore",
|
||||
"application.services.search_service.VectorCreator.create_vectorstore",
|
||||
return_value=fake_vs,
|
||||
), flask_app.app_context():
|
||||
with flask_app.test_request_context(
|
||||
@@ -328,9 +183,7 @@ class TestSearchResourcePgConn:
|
||||
def test_search_uses_extra_source_ids(self, pg_conn, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.sources import (
|
||||
SourcesRepository,
|
||||
)
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
|
||||
src1 = SourcesRepository(pg_conn).create("s1", user_id="u")
|
||||
src2 = SourcesRepository(pg_conn).create("s2", user_id="u")
|
||||
@@ -345,7 +198,7 @@ class TestSearchResourcePgConn:
|
||||
{"text": "one", "metadata": {"title": "A"}},
|
||||
]
|
||||
with _patch_search_db(pg_conn), patch(
|
||||
"application.api.answer.routes.search.VectorCreator.create_vectorstore",
|
||||
"application.services.search_service.VectorCreator.create_vectorstore",
|
||||
return_value=fake_vs,
|
||||
), flask_app.app_context():
|
||||
with flask_app.test_request_context(
|
||||
@@ -353,71 +206,3 @@ class TestSearchResourcePgConn:
|
||||
):
|
||||
result = SearchResource().post()
|
||||
assert result.status_code == 200
|
||||
|
||||
def test_search_exception_returns_500(self, pg_conn, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.sources import (
|
||||
SourcesRepository,
|
||||
)
|
||||
|
||||
src = SourcesRepository(pg_conn).create("src", user_id="u")
|
||||
AgentsRepository(pg_conn).create(
|
||||
"u", "a", "published",
|
||||
key="err-key",
|
||||
source_id=str(src["id"]),
|
||||
)
|
||||
|
||||
with _patch_search_db(pg_conn), patch(
|
||||
"application.api.answer.routes.search.SearchResource._get_sources_from_api_key",
|
||||
side_effect=RuntimeError("boom"),
|
||||
), flask_app.app_context():
|
||||
with flask_app.test_request_context(
|
||||
json={"question": "q", "api_key": "err-key"},
|
||||
):
|
||||
result = SearchResource().post()
|
||||
assert result.status_code == 500
|
||||
|
||||
|
||||
class TestGetSourcesFromApiKeyPg:
|
||||
def test_empty_for_unknown_key(self, pg_conn, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
|
||||
with _patch_search_db(pg_conn), flask_app.app_context():
|
||||
got = SearchResource()._get_sources_from_api_key("nope")
|
||||
assert got == []
|
||||
|
||||
def test_returns_extra_source_ids(self, pg_conn, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.sources import (
|
||||
SourcesRepository,
|
||||
)
|
||||
|
||||
src = SourcesRepository(pg_conn).create("s", user_id="u")
|
||||
AgentsRepository(pg_conn).create(
|
||||
"u", "a", "published",
|
||||
key="sources-key",
|
||||
extra_source_ids=[str(src["id"])],
|
||||
)
|
||||
with _patch_search_db(pg_conn), flask_app.app_context():
|
||||
got = SearchResource()._get_sources_from_api_key("sources-key")
|
||||
assert got == [str(src["id"])]
|
||||
|
||||
def test_falls_back_to_single_source(self, pg_conn, flask_app):
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.sources import (
|
||||
SourcesRepository,
|
||||
)
|
||||
|
||||
src = SourcesRepository(pg_conn).create("s", user_id="u")
|
||||
AgentsRepository(pg_conn).create(
|
||||
"u", "a", "published",
|
||||
key="single-key",
|
||||
source_id=str(src["id"]),
|
||||
)
|
||||
with _patch_search_db(pg_conn), flask_app.app_context():
|
||||
got = SearchResource()._get_sources_from_api_key("single-key")
|
||||
assert got == [str(src["id"])]
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.23
|
||||
pytest-cov>=4.1.0
|
||||
coverage>=7.4.0
|
||||
pytest-postgresql>=6.0.0
|
||||
|
||||
0
tests/services/__init__.py
Normal file
0
tests/services/__init__.py
Normal file
134
tests/services/test_mcp_server.py
Normal file
134
tests/services/test_mcp_server.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Tests for application/mcp_server.py.
|
||||
|
||||
The server module exposes one FastMCP tool, ``search_docs``, that reads
|
||||
the caller's ``Authorization: Bearer <key>`` header via
|
||||
``get_http_headers()`` and delegates to
|
||||
``application.services.search_service.search``. These tests exercise
|
||||
the tool directly by patching ``get_http_headers`` and ``search``; the
|
||||
full HTTP-layer plumbing (mount, lifespan, session handshake) is
|
||||
covered by ``tests/test_asgi.py``.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSearchDocsTool:
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_bearer_raises_permission_error(self):
|
||||
from application.mcp_server import search_docs
|
||||
|
||||
with patch(
|
||||
"application.mcp_server.get_http_headers", return_value={}
|
||||
):
|
||||
with pytest.raises(PermissionError):
|
||||
await search_docs(query="hi")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_bearer_header_raises_permission_error(self):
|
||||
from application.mcp_server import search_docs
|
||||
|
||||
with patch(
|
||||
"application.mcp_server.get_http_headers",
|
||||
return_value={"authorization": "Basic dXNlcjpwYXNz"},
|
||||
):
|
||||
with pytest.raises(PermissionError):
|
||||
await search_docs(query="hi")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blank_bearer_token_raises_permission_error(self):
|
||||
from application.mcp_server import search_docs
|
||||
|
||||
with patch(
|
||||
"application.mcp_server.get_http_headers",
|
||||
return_value={"authorization": "Bearer "},
|
||||
):
|
||||
with pytest.raises(PermissionError):
|
||||
await search_docs(query="hi")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_api_key_raises_permission_error(self):
|
||||
from application.mcp_server import search_docs
|
||||
from application.services.search_service import InvalidAPIKey
|
||||
|
||||
with (
|
||||
patch(
|
||||
"application.mcp_server.get_http_headers",
|
||||
return_value={"authorization": "Bearer bogus"},
|
||||
),
|
||||
patch(
|
||||
"application.mcp_server.search", side_effect=InvalidAPIKey()
|
||||
),
|
||||
):
|
||||
with pytest.raises(PermissionError):
|
||||
await search_docs(query="hi")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_failed_bubbles_up(self):
|
||||
from application.mcp_server import search_docs
|
||||
from application.services.search_service import SearchFailed
|
||||
|
||||
with (
|
||||
patch(
|
||||
"application.mcp_server.get_http_headers",
|
||||
return_value={"authorization": "Bearer k"},
|
||||
),
|
||||
patch(
|
||||
"application.mcp_server.search",
|
||||
side_effect=SearchFailed("boom"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(SearchFailed):
|
||||
await search_docs(query="hi")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_happy_path_passes_args_and_returns_hits(self):
|
||||
from application.mcp_server import search_docs
|
||||
|
||||
hits = [{"text": "t", "title": "T", "source": "s"}]
|
||||
with (
|
||||
patch(
|
||||
"application.mcp_server.get_http_headers",
|
||||
return_value={"authorization": "Bearer the-key"},
|
||||
),
|
||||
patch(
|
||||
"application.mcp_server.search", return_value=hits
|
||||
) as mock_search,
|
||||
):
|
||||
out = await search_docs(query="q", chunks=7)
|
||||
assert out == hits
|
||||
mock_search.assert_called_once_with("the-key", "q", 7)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_chunks_is_5(self):
|
||||
from application.mcp_server import search_docs
|
||||
|
||||
with (
|
||||
patch(
|
||||
"application.mcp_server.get_http_headers",
|
||||
return_value={"authorization": "Bearer k"},
|
||||
),
|
||||
patch(
|
||||
"application.mcp_server.search", return_value=[]
|
||||
) as mock_search,
|
||||
):
|
||||
await search_docs(query="q")
|
||||
mock_search.assert_called_once_with("k", "q", 5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bearer_scheme_case_insensitive(self):
|
||||
from application.mcp_server import search_docs
|
||||
|
||||
with (
|
||||
patch(
|
||||
"application.mcp_server.get_http_headers",
|
||||
return_value={"authorization": "bearer lowercase-scheme"},
|
||||
),
|
||||
patch(
|
||||
"application.mcp_server.search", return_value=[]
|
||||
) as mock_search,
|
||||
):
|
||||
await search_docs(query="q")
|
||||
mock_search.assert_called_once_with("lowercase-scheme", "q", 5)
|
||||
230
tests/services/test_search_service.py
Normal file
230
tests/services/test_search_service.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""Unit tests for application/services/search_service.py.
|
||||
|
||||
Tests exercise the service function in isolation — AgentsRepository is
|
||||
stubbed via a patched ``db_readonly`` context manager, and
|
||||
``VectorCreator.create_vectorstore`` is patched to return a fake
|
||||
vectorstore. No Flask app context, no real DB, no real embeddings.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.services.search_service import (
|
||||
InvalidAPIKey,
|
||||
SearchFailed,
|
||||
_collect_source_ids,
|
||||
search,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _fake_db_readonly(agent_data):
|
||||
"""Patch ``db_readonly`` so ``AgentsRepository.find_by_key`` returns ``agent_data``."""
|
||||
agents_repo = MagicMock()
|
||||
agents_repo.find_by_key.return_value = agent_data
|
||||
|
||||
@contextmanager
|
||||
def _yield_conn():
|
||||
yield MagicMock()
|
||||
|
||||
with patch(
|
||||
"application.services.search_service.db_readonly", _yield_conn
|
||||
), patch(
|
||||
"application.services.search_service.AgentsRepository",
|
||||
return_value=agents_repo,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCollectSourceIds:
|
||||
def test_empty_when_no_sources(self):
|
||||
assert _collect_source_ids({}) == []
|
||||
|
||||
def test_returns_extra_source_ids(self):
|
||||
agent = {"extra_source_ids": ["s1", "s2"], "source_id": "legacy"}
|
||||
assert _collect_source_ids(agent) == ["s1", "s2"]
|
||||
|
||||
def test_falls_back_to_single_source_id(self):
|
||||
agent = {"extra_source_ids": [], "source_id": "s1"}
|
||||
assert _collect_source_ids(agent) == ["s1"]
|
||||
|
||||
def test_skips_empty_entries_in_extra(self):
|
||||
agent = {"extra_source_ids": ["", None, "s1"], "source_id": "fallback"}
|
||||
assert _collect_source_ids(agent) == ["s1"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSearchInvalidAPIKey:
|
||||
def test_raises_when_key_unknown(self):
|
||||
with _fake_db_readonly(None):
|
||||
with pytest.raises(InvalidAPIKey):
|
||||
search("does-not-exist", "hello", 5)
|
||||
|
||||
def test_raises_search_failed_on_db_error(self):
|
||||
@contextmanager
|
||||
def _yield_conn():
|
||||
yield MagicMock()
|
||||
|
||||
agents_repo = MagicMock()
|
||||
agents_repo.find_by_key.side_effect = RuntimeError("db down")
|
||||
|
||||
with patch(
|
||||
"application.services.search_service.db_readonly", _yield_conn
|
||||
), patch(
|
||||
"application.services.search_service.AgentsRepository",
|
||||
return_value=agents_repo,
|
||||
):
|
||||
with pytest.raises(SearchFailed):
|
||||
search("any-key", "hello", 5)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSearchEmptyWhenNoSources:
|
||||
def test_returns_empty_when_agent_has_no_sources(self):
|
||||
with _fake_db_readonly({"extra_source_ids": [], "source_id": None}):
|
||||
assert search("k", "q", 5) == []
|
||||
|
||||
def test_returns_empty_for_zero_chunks_without_db_lookup(self):
|
||||
with patch("application.services.search_service.db_readonly") as mock_db:
|
||||
assert search("k", "q", 0) == []
|
||||
mock_db.assert_not_called()
|
||||
|
||||
def test_returns_empty_for_negative_chunks_without_db_lookup(self):
|
||||
with patch("application.services.search_service.db_readonly") as mock_db:
|
||||
assert search("k", "q", -1) == []
|
||||
mock_db.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSearchResults:
|
||||
def test_returns_hit_shape(self):
|
||||
agent = {"source_id": "src-1", "extra_source_ids": []}
|
||||
fake_vs = MagicMock()
|
||||
fake_vs.search.return_value = [
|
||||
{
|
||||
"text": "Test content",
|
||||
"metadata": {"title": "Test Title", "source": "/path/to/doc"},
|
||||
}
|
||||
]
|
||||
with _fake_db_readonly(agent), patch(
|
||||
"application.services.search_service.VectorCreator.create_vectorstore",
|
||||
return_value=fake_vs,
|
||||
):
|
||||
results = search("k", "q", 5)
|
||||
assert results == [
|
||||
{"text": "Test content", "title": "Test Title", "source": "/path/to/doc"}
|
||||
]
|
||||
|
||||
def test_handles_langchain_document_format(self):
|
||||
agent = {"source_id": "src-1", "extra_source_ids": []}
|
||||
lc_doc = MagicMock()
|
||||
lc_doc.page_content = "Langchain content"
|
||||
lc_doc.metadata = {"title": "LC Title", "source": "/lc/path"}
|
||||
|
||||
fake_vs = MagicMock()
|
||||
fake_vs.search.return_value = [lc_doc]
|
||||
|
||||
with _fake_db_readonly(agent), patch(
|
||||
"application.services.search_service.VectorCreator.create_vectorstore",
|
||||
return_value=fake_vs,
|
||||
):
|
||||
results = search("k", "q", 5)
|
||||
assert len(results) == 1
|
||||
assert results[0]["text"] == "Langchain content"
|
||||
assert results[0]["title"] == "LC Title"
|
||||
|
||||
def test_respects_chunks_cap(self):
|
||||
agent = {"source_id": "src-1", "extra_source_ids": []}
|
||||
docs = [
|
||||
{"text": f"Content {i}", "metadata": {"title": f"T{i}"}}
|
||||
for i in range(10)
|
||||
]
|
||||
fake_vs = MagicMock()
|
||||
fake_vs.search.return_value = docs
|
||||
|
||||
with _fake_db_readonly(agent), patch(
|
||||
"application.services.search_service.VectorCreator.create_vectorstore",
|
||||
return_value=fake_vs,
|
||||
):
|
||||
results = search("k", "q", 3)
|
||||
assert len(results) == 3
|
||||
|
||||
def test_deduplicates_results_by_content_prefix(self):
|
||||
agent = {"source_id": "src-1", "extra_source_ids": []}
|
||||
dup_text = "Duplicate content " * 20
|
||||
docs = [
|
||||
{"text": dup_text, "metadata": {"title": "T1"}},
|
||||
{"text": dup_text, "metadata": {"title": "T2"}},
|
||||
{"text": "Unique content", "metadata": {"title": "T3"}},
|
||||
]
|
||||
fake_vs = MagicMock()
|
||||
fake_vs.search.return_value = docs
|
||||
|
||||
with _fake_db_readonly(agent), patch(
|
||||
"application.services.search_service.VectorCreator.create_vectorstore",
|
||||
return_value=fake_vs,
|
||||
):
|
||||
results = search("k", "q", 5)
|
||||
assert len(results) == 2
|
||||
|
||||
def test_skips_broken_source_and_returns_from_healthy_ones(self):
|
||||
# Two sources — the first raises, the second returns a doc. The
|
||||
# caller should still get the healthy source's result.
|
||||
agent = {"extra_source_ids": ["broken", "ok"], "source_id": None}
|
||||
healthy_vs = MagicMock()
|
||||
healthy_vs.search.return_value = [
|
||||
{"text": "ok content", "metadata": {"title": "Ok"}}
|
||||
]
|
||||
|
||||
def create_vs(store, source_id, key):
|
||||
if source_id == "broken":
|
||||
raise RuntimeError("vector index missing")
|
||||
return healthy_vs
|
||||
|
||||
with _fake_db_readonly(agent), patch(
|
||||
"application.services.search_service.VectorCreator.create_vectorstore",
|
||||
side_effect=create_vs,
|
||||
):
|
||||
results = search("k", "q", 5)
|
||||
assert len(results) == 1
|
||||
assert results[0]["text"] == "ok content"
|
||||
|
||||
def test_uses_filename_when_title_missing(self):
|
||||
agent = {"source_id": "src-1", "extra_source_ids": []}
|
||||
fake_vs = MagicMock()
|
||||
fake_vs.search.return_value = [
|
||||
{"text": "body", "metadata": {"filename": "document.pdf"}}
|
||||
]
|
||||
with _fake_db_readonly(agent), patch(
|
||||
"application.services.search_service.VectorCreator.create_vectorstore",
|
||||
return_value=fake_vs,
|
||||
):
|
||||
results = search("k", "q", 5)
|
||||
assert results[0]["title"] == "document.pdf"
|
||||
|
||||
def test_uses_content_snippet_as_title_last_resort(self):
|
||||
agent = {"source_id": "src-1", "extra_source_ids": []}
|
||||
fake_vs = MagicMock()
|
||||
fake_vs.search.return_value = [
|
||||
{"text": "Content without any title metadata at all", "metadata": {}}
|
||||
]
|
||||
with _fake_db_readonly(agent), patch(
|
||||
"application.services.search_service.VectorCreator.create_vectorstore",
|
||||
return_value=fake_vs,
|
||||
):
|
||||
results = search("k", "q", 5)
|
||||
assert results[0]["title"].endswith("...")
|
||||
assert "Content without any title" in results[0]["title"]
|
||||
|
||||
def test_skips_empty_source_ids(self):
|
||||
# ``source_id=" "`` only — after strip() this leaves no real source.
|
||||
agent = {"extra_source_ids": [" ", ""], "source_id": None}
|
||||
with _fake_db_readonly(agent), patch(
|
||||
"application.services.search_service.VectorCreator.create_vectorstore"
|
||||
) as mock_create:
|
||||
results = search("k", "q", 5)
|
||||
mock_create.assert_not_called()
|
||||
assert results == []
|
||||
@@ -105,11 +105,26 @@ class TestAuthenticateRequest:
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestAfterRequest:
|
||||
class TestFlaskCors:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cors_headers(self, client):
|
||||
response = client.get("/api/health")
|
||||
assert response.headers.get("Access-Control-Allow-Origin") == "*"
|
||||
assert "Content-Type" in response.headers.get("Access-Control-Allow-Headers", "")
|
||||
assert "GET" in response.headers.get("Access-Control-Allow-Methods", "")
|
||||
def test_cors_headers_on_flask_route(self, client):
|
||||
response = client.get("/api/health", headers={"Origin": "http://localhost:5173"})
|
||||
assert response.headers["Access-Control-Allow-Origin"] == "*"
|
||||
assert response.headers["Access-Control-Allow-Headers"] == "Content-Type, Authorization"
|
||||
assert response.headers["Access-Control-Allow-Methods"] == "GET, POST, PUT, DELETE, OPTIONS"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cors_headers_on_flask_preflight(self, client):
|
||||
response = client.options(
|
||||
"/api/health",
|
||||
headers={
|
||||
"Origin": "http://localhost:5173",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
"Access-Control-Request-Headers": "Content-Type",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["Access-Control-Allow-Origin"] == "*"
|
||||
assert response.headers["Access-Control-Allow-Headers"] == "Content-Type, Authorization"
|
||||
assert response.headers["Access-Control-Allow-Methods"] == "GET, POST, PUT, DELETE, OPTIONS"
|
||||
|
||||
136
tests/test_asgi.py
Normal file
136
tests/test_asgi.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Smoke tests for application/asgi.py.
|
||||
|
||||
The goal isn't to re-test Flask or FastMCP internals — it's to catch
|
||||
regressions in the wiring: mounts resolve, CORS headers emit, lifespan
|
||||
runs (without it, the /mcp session manager raises "Task group is not
|
||||
initialized"), routing to ``/`` vs ``/mcp`` doesn't cross paths.
|
||||
|
||||
Uses ``starlette.testclient.TestClient`` because it boots the ASGI app
|
||||
end-to-end and handles the lifespan protocol automatically — ``httpx``
|
||||
alone does not run lifespan events, which would mask the exact kind of
|
||||
misconfiguration this test suite exists to catch.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_asgi_app_imports():
|
||||
from application.asgi import asgi_app
|
||||
|
||||
assert asgi_app is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_flask_route_served_through_starlette_mount():
|
||||
"""GET /api/health should reach the Flask app via a2wsgi and return 200."""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from application.asgi import asgi_app
|
||||
|
||||
with TestClient(asgi_app) as client:
|
||||
r = client.get("/api/health")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_mcp_endpoint_mounted_and_lifespan_runs():
|
||||
"""/mcp must be reachable AND the FastMCP session manager must start.
|
||||
|
||||
Without ``lifespan=mcp_app.lifespan`` on the outer Starlette app,
|
||||
every /mcp request raises ``RuntimeError: Task group is not
|
||||
initialized``. Hitting the endpoint under a real lifespan-aware
|
||||
client catches that.
|
||||
"""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from application.asgi import asgi_app
|
||||
|
||||
with TestClient(asgi_app) as client:
|
||||
# Minimal MCP initialize request. Doesn't need to succeed — we
|
||||
# just need a non-404, non-500-with-RuntimeError response to
|
||||
# confirm the mount + lifespan are both wired.
|
||||
r = client.post(
|
||||
"/mcp/",
|
||||
headers={
|
||||
"Origin": "http://example.com",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "pytest", "version": "0"},
|
||||
},
|
||||
},
|
||||
)
|
||||
assert r.status_code != 404, f"/mcp mount unreachable: {r.status_code}"
|
||||
# A successful initialize returns 200 with a Mcp-Session-Id header.
|
||||
assert r.status_code == 200
|
||||
assert "mcp-session-id" in {k.lower() for k in r.headers.keys()}
|
||||
assert r.headers.get("access-control-expose-headers") == "Mcp-Session-Id"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cors_headers_on_flask_route():
|
||||
"""CORS middleware should emit allow-origin on actual (non-preflight) requests.
|
||||
|
||||
``allow_origins=["*"]`` → header value is literal ``*`` (not an echo).
|
||||
"""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from application.asgi import asgi_app
|
||||
|
||||
with TestClient(asgi_app) as client:
|
||||
r = client.get("/api/health", headers={"Origin": "http://example.com"})
|
||||
assert r.status_code == 200
|
||||
assert r.headers.get("access-control-allow-origin") == "*"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cors_preflight_on_flask_route():
|
||||
"""OPTIONS preflight on a Flask route should be handled by Starlette CORSMiddleware."""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from application.asgi import asgi_app
|
||||
|
||||
with TestClient(asgi_app) as client:
|
||||
r = client.options(
|
||||
"/api/health",
|
||||
headers={
|
||||
"Origin": "http://example.com",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
"Access-Control-Request-Headers": "Content-Type",
|
||||
},
|
||||
)
|
||||
assert r.status_code in (200, 204)
|
||||
assert r.headers.get("access-control-allow-origin") == "*"
|
||||
assert "GET" in r.headers.get("access-control-allow-methods", "")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cors_preflight_on_mcp_route():
|
||||
"""Browser clients hitting /mcp should be allowed to send session headers."""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from application.asgi import asgi_app
|
||||
|
||||
with TestClient(asgi_app) as client:
|
||||
r = client.options(
|
||||
"/mcp/",
|
||||
headers={
|
||||
"Origin": "http://example.com",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
"Access-Control-Request-Headers": (
|
||||
"Authorization, Content-Type, Mcp-Session-Id"
|
||||
),
|
||||
},
|
||||
)
|
||||
assert r.status_code in (200, 204)
|
||||
assert r.headers.get("access-control-allow-origin") == "*"
|
||||
assert "Mcp-Session-Id" in r.headers.get("access-control-allow-headers", "")
|
||||
Reference in New Issue
Block a user