feat: asgi and search service

This commit is contained in:
Alex
2026-04-22 09:12:29 +01:00
parent d4b1c1fd81
commit a5153d5212
9 changed files with 612 additions and 426 deletions

View File

@@ -1,21 +1,21 @@
import logging
from typing import Any, Dict, List
from flask import make_response, request
from flask_restx import fields, Resource
from application.api.answer.routes.base import answer_ns
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly
from application.vectorstore.vector_creator import VectorCreator
from application.services.search_service import (
InvalidAPIKey,
SearchFailed,
search,
)
logger = logging.getLogger(__name__)
@answer_ns.route("/api/search")
class SearchResource(Resource):
"""Fast search endpoint for retrieving relevant documents"""
"""Fast search endpoint for retrieving relevant documents."""
search_model = answer_ns.model(
"SearchModel",
@@ -32,102 +32,10 @@ class SearchResource(Resource):
},
)
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
"""Get source IDs connected to the API key/agent."""
with db_readonly() as conn:
agent_data = AgentsRepository(conn).find_by_key(api_key)
if not agent_data:
return []
source_ids: List[str] = []
# extra_source_ids is a PG ARRAY(UUID) of source UUIDs.
extra = agent_data.get("extra_source_ids") or []
for src in extra:
if src:
source_ids.append(str(src))
if not source_ids:
single = agent_data.get("source_id")
if single:
source_ids.append(str(single))
return source_ids
def _search_vectorstores(
self, query: str, source_ids: List[str], chunks: int
) -> List[Dict[str, Any]]:
"""Search across vectorstores and return results"""
if not source_ids:
return []
results = []
chunks_per_source = max(1, chunks // len(source_ids))
seen_texts = set()
for source_id in source_ids:
if not source_id or not source_id.strip():
continue
try:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
)
docs = docsearch.search(query, k=chunks_per_source * 2)
for doc in docs:
if len(results) >= chunks:
break
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
page_content = doc.page_content
metadata = doc.metadata
else:
page_content = doc.get("text", doc.get("page_content", ""))
metadata = doc.get("metadata", {})
# Skip duplicates
text_hash = hash(page_content[:200])
if text_hash in seen_texts:
continue
seen_texts.add(text_hash)
title = metadata.get(
"title", metadata.get("post_title", "")
)
if not isinstance(title, str):
title = str(title) if title else ""
# Clean up title
if title:
title = title.split("/")[-1]
else:
# Use filename or first part of content as title
title = metadata.get("filename", page_content[:50] + "...")
source = metadata.get("source", source_id)
results.append({
"text": page_content,
"title": title,
"source": source,
})
if len(results) >= chunks:
break
except Exception as e:
logger.error(
f"Error searching vectorstore {source_id}: {e}",
exc_info=True,
)
continue
return results[:chunks]
@answer_ns.expect(search_model)
@answer_ns.doc(description="Search for relevant documents based on query")
def post(self):
data = request.get_json()
data = request.get_json() or {}
question = data.get("question")
api_key = data.get("api_key")
@@ -135,32 +43,13 @@ class SearchResource(Resource):
if not question:
return make_response({"error": "question is required"}, 400)
if not api_key:
return make_response({"error": "api_key is required"}, 400)
# Validate API key
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
if not agent:
return make_response({"error": "Invalid API key"}, 401)
try:
# Get sources connected to this API key
source_ids = self._get_sources_from_api_key(api_key)
if not source_ids:
return make_response([], 200)
# Perform search
results = self._search_vectorstores(question, source_ids, chunks)
return make_response(results, 200)
except Exception as e:
logger.error(
f"/api/search - error: {str(e)}",
extra={"error": str(e)},
exc_info=True,
)
return make_response(search(api_key, question, chunks), 200)
except InvalidAPIKey:
return make_response({"error": "Invalid API key"}, 401)
except SearchFailed:
logger.exception("/api/search failed")
return make_response({"error": "Search failed"}, 500)

5
application/asgi.py Normal file
View File

@@ -0,0 +1,5 @@
"""ASGI entrypoint: exposes the Flask (WSGI) app to ASGI servers via a2wsgi."""
from a2wsgi import WSGIMiddleware
from application.app import app as flask_app
# ``asgi_app`` is the object an ASGI server should import, e.g.
# ``uvicorn application.asgi:asgi_app``.
asgi_app = WSGIMiddleware(flask_app)

View File

@@ -14,7 +14,7 @@ docx2txt==0.9
ddgs>=8.0.0
fast-ebook
elevenlabs==2.43.0
Flask==3.1.3
Flask==3.1.1
faiss-cpu==1.13.2
fastmcp==3.2.4
flask-restx==1.3.2

View File

View File

@@ -0,0 +1,150 @@
"""Shared retrieval service used by the HTTP search route and the MCP tool.
Flask-free. Raises domain exceptions (``InvalidAPIKey``, ``SearchFailed``)
that callers translate into their own wire protocol (HTTP status codes,
MCP error responses, etc.).
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly
from application.vectorstore.vector_creator import VectorCreator
logger = logging.getLogger(__name__)
class InvalidAPIKey(Exception):
    """The supplied ``api_key`` does not resolve to an agent.

    Callers translate this into their own wire protocol (the HTTP route
    maps it to a 401 response).
    """
class SearchFailed(Exception):
    """Unexpected error during retrieval (e.g. DB outage). Caller maps to 5xx."""
def _collect_source_ids(agent: Dict[str, Any]) -> List[str]:
"""Extract the ordered list of source UUIDs to search.
Prefers ``extra_source_ids`` (PG ARRAY(UUID) of multi-source agents);
falls back to the legacy single ``source_id`` field.
"""
source_ids: List[str] = []
extra = agent.get("extra_source_ids") or []
for src in extra:
if src:
source_ids.append(str(src))
if not source_ids:
single = agent.get("source_id")
if single:
source_ids.append(str(single))
return source_ids
def _search_sources(
query: str, source_ids: List[str], chunks: int
) -> List[Dict[str, Any]]:
"""Search across each source's vectorstore and return up to ``chunks`` hits.
Per-source errors are logged and skipped so one broken index doesn't
take down the whole search. Results are de-duplicated by content hash.
"""
if not source_ids:
return []
results: List[Dict[str, Any]] = []
chunks_per_source = max(1, chunks // len(source_ids))
seen_texts: set[int] = set()
for source_id in source_ids:
if not source_id or not source_id.strip():
continue
try:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
)
docs = docsearch.search(query, k=chunks_per_source * 2)
for doc in docs:
if len(results) >= chunks:
break
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
page_content = doc.page_content
metadata = doc.metadata
else:
page_content = doc.get("text", doc.get("page_content", ""))
metadata = doc.get("metadata", {})
text_hash = hash(page_content[:200])
if text_hash in seen_texts:
continue
seen_texts.add(text_hash)
title = metadata.get("title", metadata.get("post_title", ""))
if not isinstance(title, str):
title = str(title) if title else ""
if title:
title = title.split("/")[-1]
else:
title = metadata.get("filename", page_content[:50] + "...")
source = metadata.get("source", source_id)
results.append(
{
"text": page_content,
"title": title,
"source": source,
}
)
if len(results) >= chunks:
break
except Exception as e:
logger.error(
f"Error searching vectorstore {source_id}: {e}",
exc_info=True,
)
continue
return results[:chunks]
def search(api_key: str, query: str, chunks: int = 5) -> List[Dict[str, Any]]:
    """Resolve an agent by API key and search its configured sources.

    Args:
        api_key: Opaque agent key (``agents.key`` in Postgres).
        query: Free-text search query.
        chunks: Maximum number of hits to return.

    Returns:
        List of hit dicts with ``text``, ``title`` and ``source`` keys;
        empty list when the agent has no sources configured.

    Raises:
        InvalidAPIKey: if ``api_key`` does not resolve to an agent.
        SearchFailed: on unexpected DB / infrastructure errors.
    """
    try:
        with db_readonly() as conn:
            agent = AgentsRepository(conn).find_by_key(api_key)
    except Exception as exc:
        raise SearchFailed("agent lookup failed") from exc
    if not agent:
        raise InvalidAPIKey()
    sources = _collect_source_ids(agent)
    return _search_sources(query, sources, chunks) if sources else []

137
scripts/mock_llm.py Normal file
View File

@@ -0,0 +1,137 @@
"""Mock OpenAI-compatible LLM server for benchmarking.
Fixed 5-second generation (100 tokens × 50 ms/token). No auth. Emits SSE
chunks in OpenAI's chat.completions streaming format, or a single response
when stream=false. Run on 127.0.0.1:8090 — point DocsGPT at it via
OPENAI_BASE_URL=http://127.0.0.1:8090/v1.
"""
import asyncio
import json
import logging
import time
import uuid
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
# Fixed generation profile: 100 tokens at 50 ms each = 5.0 s per completion.
TOKEN_COUNT = 100
TOKEN_DELAY_S = 0.05  # 100 * 0.05 = 5.0 s
logger = logging.getLogger("mock_llm")
logging.basicConfig(level=logging.INFO, format="%(asctime)s mock: %(message)s")
# Lorem-ipsum fragments used as fake completion tokens. Leading spaces and
# punctuation are embedded in each entry so plain concatenation yields
# readable text.
FILLER_TOKENS = [
    "Lorem", " ipsum", " dolor", " sit", " amet", ",", " consectetur",
    " adipiscing", " elit", ".", " Sed", " do", " eiusmod", " tempor",
    " incididunt", " ut", " labore", " et", " dolore", " magna", " aliqua",
    ".", " Ut", " enim", " ad", " minim", " veniam", ",", " quis", " nostrud",
    " exercitation", " ullamco", " laboris", " nisi", " ut", " aliquip",
    " ex", " ea", " commodo", " consequat", ".", " Duis", " aute", " irure",
    " dolor", " in", " reprehenderit", " in", " voluptate", " velit",
    " esse", " cillum", " dolore", " eu", " fugiat", " nulla", " pariatur",
    ".", " Excepteur", " sint", " occaecat", " cupidatat", " non", " proident",
    ",", " sunt", " in", " culpa", " qui", " officia", " deserunt",
    " mollit", " anim", " id", " est", " laborum", ".", " Curabitur",
    " pretium", " tincidunt", " lacus", ".", " Nulla", " gravida", " orci",
    " a", " odio", ".", " Nullam", " varius", ",", " turpis", " et",
    " commodo", " pharetra", ",", " est", " eros", " bibendum", " elit",
    ".",
]
app = FastAPI()
def _token_stream_id() -> str:
return f"chatcmpl-mock-{uuid.uuid4().hex[:12]}"
def _sse_chunk(completion_id: str, model: str, delta: dict, finish_reason=None) -> str:
payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
return f"data: {json.dumps(payload)}\n\n"
async def _stream_response(model: str, req_id: str):
    """Yield an OpenAI-style SSE stream for one mock completion.

    Emits a role preamble chunk, then ``TOKEN_COUNT`` content chunks paced
    at ``TOKEN_DELAY_S`` intervals, a stop chunk, and the ``[DONE]``
    sentinel.

    Args:
        model: Model name to echo back in every chunk.
        req_id: Short id used only for log correlation.
    """
    completion_id = _token_stream_id()
    # First chunk carries the assistant role with empty content, matching
    # OpenAI's streaming format.
    yield _sse_chunk(completion_id, model, {"role": "assistant", "content": ""})
    # Fix: the original used ``enumerate`` but never read the index.
    for tok in FILLER_TOKENS[:TOKEN_COUNT]:
        await asyncio.sleep(TOKEN_DELAY_S)
        yield _sse_chunk(completion_id, model, {"content": tok})
    yield _sse_chunk(completion_id, model, {}, finish_reason="stop")
    yield "data: [DONE]\n\n"
    logger.info("[%s] stream done", req_id)
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Mock OpenAI chat-completions endpoint.

    Streams SSE chunks when the request body sets ``stream``; otherwise
    sleeps for the same total generation time and returns one JSON
    completion. No authentication is performed.
    """
    body = await request.json()
    model = body.get("model", "mock")
    stream = bool(body.get("stream", False))
    # Short per-request id used only for log correlation.
    req_id = uuid.uuid4().hex[:8]
    logger.info("[%s] /chat/completions stream=%s model=%s max_tokens=%s", req_id, stream, model, body.get("max_tokens"))
    if stream:
        return StreamingResponse(
            _stream_response(model, req_id),
            media_type="text/event-stream",
            # Disable caching/proxy buffering so chunks reach the client
            # at the intended cadence.
            headers={
                "Cache-Control": "no-cache, no-transform",
                "X-Accel-Buffering": "no",
            },
        )
    # Non-streaming path: simulate the same total generation latency.
    await asyncio.sleep(TOKEN_COUNT * TOKEN_DELAY_S)
    logger.info("[%s] non-stream done", req_id)
    text = "".join(FILLER_TOKENS[:TOKEN_COUNT])
    completion_id = _token_stream_id()
    return JSONResponse(
        {
            "id": completion_id,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": text},
                    "finish_reason": "stop",
                }
            ],
            # Fixed fake usage numbers; only completion_tokens is meaningful.
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": TOKEN_COUNT,
                "total_tokens": 10 + TOKEN_COUNT,
            },
        }
    )
@app.get("/v1/models")
async def list_models():
    """Minimal ``/v1/models`` response so OpenAI clients can discover the mock model."""
    return {
        "object": "list",
        "data": [{"id": "mock", "object": "model", "owned_by": "mock"}],
    }
@app.get("/health")
async def health():
    """Simple liveness endpoint."""
    return {"status": "ok"}
if __name__ == "__main__":
    import uvicorn

    # Bind to loopback only — this mock performs no authentication.
    uvicorn.run(app, host="127.0.0.1", port=8090, log_level="info")

View File

@@ -1,3 +1,17 @@
"""Tests for /api/search route (application/api/answer/routes/search.py).
Retrieval logic lives in ``application/services/search_service.py`` and
has its own unit tests in ``tests/services/test_search_service.py``. The
tests below focus on what the route specifically owns:
* Request validation (400 for missing fields).
* Translation of the service's ``InvalidAPIKey`` / ``SearchFailed``
exceptions to HTTP status codes (401 / 500).
* End-to-end happy path against a real ephemeral Postgres via
``pg_conn``, to catch regressions in the route's wiring to the
service and repositories.
"""
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
@@ -6,254 +20,97 @@ import pytest
@pytest.mark.unit
class TestSearchResourceValidation:
pass
def test_returns_error_when_question_missing(self, mock_mongo_db, flask_app):
def test_returns_400_when_question_missing(self, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
with flask_app.test_request_context(
json={"api_key": "test_key"}
):
resource = SearchResource()
result = resource.post()
with flask_app.test_request_context(json={"api_key": "test_key"}):
result = SearchResource().post()
assert result.status_code == 400
assert "question" in result.json["error"]
def test_returns_error_when_api_key_missing(self, mock_mongo_db, flask_app):
def test_returns_400_when_api_key_missing(self, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
with flask_app.test_request_context(
json={"question": "test query"}
):
resource = SearchResource()
result = resource.post()
with flask_app.test_request_context(json={"question": "test query"}):
result = SearchResource().post()
assert result.status_code == 400
assert "api_key" in result.json["error"]
@pytest.mark.unit
class TestGetSourcesFromApiKey:
pass
class TestSearchResourceExceptionMapping:
"""Verify the route maps service exceptions to HTTP status codes.
def test_returns_source_id_via_patched_method(self, mock_mongo_db, flask_app):
"""Test that _get_sources_from_api_key can return multiple sources via patch."""
The service function itself is patched; these tests do not care about
the search logic — only that 401/500/200 are produced correctly from
the three possible service outcomes.
"""
def test_invalid_api_key_returns_401(self, flask_app):
from application.api.answer.routes.search import SearchResource
from application.services.search_service import InvalidAPIKey
with flask_app.app_context(), flask_app.test_request_context(
json={"question": "q", "api_key": "bad"}
), patch(
"application.api.answer.routes.search.search",
side_effect=InvalidAPIKey(),
):
result = SearchResource().post()
assert result.status_code == 401
assert result.json == {"error": "Invalid API key"}
def test_search_failed_returns_500(self, flask_app):
from application.api.answer.routes.search import SearchResource
from application.services.search_service import SearchFailed
with flask_app.app_context(), flask_app.test_request_context(
json={"question": "q", "api_key": "k"}
), patch(
"application.api.answer.routes.search.search",
side_effect=SearchFailed("boom"),
):
result = SearchResource().post()
assert result.status_code == 500
assert result.json == {"error": "Search failed"}
def test_happy_path_passes_service_result_through(self, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
hits = [{"text": "t", "title": "T", "source": "s"}]
with flask_app.app_context(), flask_app.test_request_context(
json={"question": "q", "api_key": "k", "chunks": 7}
), patch(
"application.api.answer.routes.search.search",
return_value=hits,
) as mock_search:
result = SearchResource().post()
assert result.status_code == 200
assert result.json == hits
mock_search.assert_called_once_with("k", "q", 7)
with patch.object(resource, "_get_sources_from_api_key", return_value=["src1", "src2"]):
result = resource._get_sources_from_api_key("any_key")
assert len(result) == 2
assert "src1" in result
assert "src2" in result
@pytest.mark.unit
class TestSearchVectorstores:
pass
def test_returns_empty_when_no_source_ids(self, mock_mongo_db, flask_app):
def test_default_chunks_is_5(self, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
result = resource._search_vectorstores("test query", [], 5)
assert result == []
def test_skips_empty_source_ids(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_vectorstore = MagicMock()
mock_vectorstore.search.return_value = []
mock_create.return_value = mock_vectorstore
result = resource._search_vectorstores("test query", ["", " "], 5)
mock_create.assert_not_called()
assert result == []
def test_returns_search_results(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
mock_doc = {
"text": "Test content",
"page_content": "Test content",
"metadata": {
"title": "Test Title",
"source": "/path/to/doc",
},
}
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_vectorstore = MagicMock()
mock_vectorstore.search.return_value = [mock_doc]
mock_create.return_value = mock_vectorstore
result = resource._search_vectorstores("test query", ["source_id"], 5)
assert len(result) == 1
assert result[0]["text"] == "Test content"
assert result[0]["title"] == "Test Title"
assert result[0]["source"] == "/path/to/doc"
def test_handles_langchain_document_format(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
mock_doc = MagicMock()
mock_doc.page_content = "Langchain content"
mock_doc.metadata = {"title": "LC Title", "source": "/lc/path"}
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_vectorstore = MagicMock()
mock_vectorstore.search.return_value = [mock_doc]
mock_create.return_value = mock_vectorstore
result = resource._search_vectorstores("test query", ["source_id"], 5)
assert len(result) == 1
assert result[0]["text"] == "Langchain content"
assert result[0]["title"] == "LC Title"
def test_respects_chunks_limit(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
mock_docs = [
{"text": f"Content {i}", "metadata": {"title": f"Title {i}"}}
for i in range(10)
]
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_vectorstore = MagicMock()
mock_vectorstore.search.return_value = mock_docs
mock_create.return_value = mock_vectorstore
result = resource._search_vectorstores("test query", ["source_id"], 3)
assert len(result) == 3
def test_deduplicates_results(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
duplicate_text = "Duplicate content " * 20
mock_docs = [
{"text": duplicate_text, "metadata": {"title": "Title 1"}},
{"text": duplicate_text, "metadata": {"title": "Title 2"}},
{"text": "Unique content", "metadata": {"title": "Title 3"}},
]
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_vectorstore = MagicMock()
mock_vectorstore.search.return_value = mock_docs
mock_create.return_value = mock_vectorstore
result = resource._search_vectorstores("test query", ["source_id"], 5)
assert len(result) == 2
def test_handles_vectorstore_error_gracefully(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_create.side_effect = Exception("Vectorstore error")
result = resource._search_vectorstores("test query", ["source_id"], 5)
assert result == []
def test_uses_filename_as_title_fallback(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
mock_doc = {
"text": "Content without title",
"metadata": {"filename": "document.pdf"},
}
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_vectorstore = MagicMock()
mock_vectorstore.search.return_value = [mock_doc]
mock_create.return_value = mock_vectorstore
result = resource._search_vectorstores("test query", ["source_id"], 5)
assert result[0]["title"] == "document.pdf"
def test_uses_content_snippet_as_title_last_resort(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
mock_doc = {
"text": "Content without any title metadata at all",
"metadata": {},
}
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_vectorstore = MagicMock()
mock_vectorstore.search.return_value = [mock_doc]
mock_create.return_value = mock_vectorstore
result = resource._search_vectorstores("test query", ["source_id"], 5)
assert "Content without any title" in result[0]["title"]
assert result[0]["title"].endswith("...")
@pytest.mark.unit
class TestSearchEndpoint:
pass
with flask_app.app_context(), flask_app.test_request_context(
json={"question": "q", "api_key": "k"} # no chunks field
), patch(
"application.api.answer.routes.search.search",
return_value=[],
) as mock_search:
SearchResource().post()
mock_search.assert_called_once_with("k", "q", 5)
# ---------------------------------------------------------------------------
# Real-PG tests for SearchResource.
# End-to-end against a real ephemeral Postgres.
#
# These exercise the full route → service → repository → DB path, patching
# only ``VectorCreator.create_vectorstore`` (so we don't need real embeddings
# or a vector index). ``db_readonly`` is redirected at the *service* module
# since that's where the import now lives.
# ---------------------------------------------------------------------------
@@ -264,7 +121,7 @@ def _patch_search_db(conn):
yield conn
with patch(
"application.api.answer.routes.search.db_readonly", _yield
"application.services.search_service.db_readonly", _yield
):
yield
@@ -298,9 +155,7 @@ class TestSearchResourcePgConn:
def test_search_returns_results(self, pg_conn, flask_app):
from application.api.answer.routes.search import SearchResource
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.sources import (
SourcesRepository,
)
from application.storage.db.repositories.sources import SourcesRepository
src = SourcesRepository(pg_conn).create("src", user_id="u")
AgentsRepository(pg_conn).create(
@@ -315,7 +170,7 @@ class TestSearchResourcePgConn:
]
with _patch_search_db(pg_conn), patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore",
"application.services.search_service.VectorCreator.create_vectorstore",
return_value=fake_vs,
), flask_app.app_context():
with flask_app.test_request_context(
@@ -328,9 +183,7 @@ class TestSearchResourcePgConn:
def test_search_uses_extra_source_ids(self, pg_conn, flask_app):
from application.api.answer.routes.search import SearchResource
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.sources import (
SourcesRepository,
)
from application.storage.db.repositories.sources import SourcesRepository
src1 = SourcesRepository(pg_conn).create("s1", user_id="u")
src2 = SourcesRepository(pg_conn).create("s2", user_id="u")
@@ -345,7 +198,7 @@ class TestSearchResourcePgConn:
{"text": "one", "metadata": {"title": "A"}},
]
with _patch_search_db(pg_conn), patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore",
"application.services.search_service.VectorCreator.create_vectorstore",
return_value=fake_vs,
), flask_app.app_context():
with flask_app.test_request_context(
@@ -353,71 +206,3 @@ class TestSearchResourcePgConn:
):
result = SearchResource().post()
assert result.status_code == 200
def test_search_exception_returns_500(self, pg_conn, flask_app):
from application.api.answer.routes.search import SearchResource
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.sources import (
SourcesRepository,
)
src = SourcesRepository(pg_conn).create("src", user_id="u")
AgentsRepository(pg_conn).create(
"u", "a", "published",
key="err-key",
source_id=str(src["id"]),
)
with _patch_search_db(pg_conn), patch(
"application.api.answer.routes.search.SearchResource._get_sources_from_api_key",
side_effect=RuntimeError("boom"),
), flask_app.app_context():
with flask_app.test_request_context(
json={"question": "q", "api_key": "err-key"},
):
result = SearchResource().post()
assert result.status_code == 500
class TestGetSourcesFromApiKeyPg:
def test_empty_for_unknown_key(self, pg_conn, flask_app):
from application.api.answer.routes.search import SearchResource
with _patch_search_db(pg_conn), flask_app.app_context():
got = SearchResource()._get_sources_from_api_key("nope")
assert got == []
def test_returns_extra_source_ids(self, pg_conn, flask_app):
from application.api.answer.routes.search import SearchResource
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.sources import (
SourcesRepository,
)
src = SourcesRepository(pg_conn).create("s", user_id="u")
AgentsRepository(pg_conn).create(
"u", "a", "published",
key="sources-key",
extra_source_ids=[str(src["id"])],
)
with _patch_search_db(pg_conn), flask_app.app_context():
got = SearchResource()._get_sources_from_api_key("sources-key")
assert got == [str(src["id"])]
def test_falls_back_to_single_source(self, pg_conn, flask_app):
from application.api.answer.routes.search import SearchResource
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.sources import (
SourcesRepository,
)
src = SourcesRepository(pg_conn).create("s", user_id="u")
AgentsRepository(pg_conn).create(
"u", "a", "published",
key="single-key",
source_id=str(src["id"]),
)
with _patch_search_db(pg_conn), flask_app.app_context():
got = SearchResource()._get_sources_from_api_key("single-key")
assert got == [str(src["id"])]

View File

View File

@@ -0,0 +1,220 @@
"""Unit tests for application/services/search_service.py.
Tests exercise the service function in isolation — AgentsRepository is
stubbed via a patched ``db_readonly`` context manager, and
``VectorCreator.create_vectorstore`` is patched to return a fake
vectorstore. No Flask app context, no real DB, no real embeddings.
"""
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from application.services.search_service import (
InvalidAPIKey,
SearchFailed,
_collect_source_ids,
search,
)
@contextmanager
def _fake_db_readonly(agent_data):
    """Patch ``db_readonly`` so ``AgentsRepository.find_by_key`` returns ``agent_data``."""
    agents_repo = MagicMock()
    agents_repo.find_by_key.return_value = agent_data

    @contextmanager
    def _yield_conn():
        # The connection is only ever handed to AgentsRepository, which is
        # itself patched below, so any object suffices here.
        yield MagicMock()

    with patch(
        "application.services.search_service.db_readonly", _yield_conn
    ), patch(
        "application.services.search_service.AgentsRepository",
        return_value=agents_repo,
    ):
        yield
@pytest.mark.unit
class TestCollectSourceIds:
    """Source-resolution order of ``_collect_source_ids``."""

    def test_empty_when_no_sources(self):
        assert _collect_source_ids({}) == []

    def test_returns_extra_source_ids(self):
        # extra_source_ids wins over the legacy single source_id.
        agent = {"extra_source_ids": ["s1", "s2"], "source_id": "legacy"}
        assert _collect_source_ids(agent) == ["s1", "s2"]

    def test_falls_back_to_single_source_id(self):
        agent = {"extra_source_ids": [], "source_id": "s1"}
        assert _collect_source_ids(agent) == ["s1"]

    def test_skips_empty_entries_in_extra(self):
        # Falsy entries ("" / None) are dropped; a surviving real id means
        # the single-source fallback is not consulted.
        agent = {"extra_source_ids": ["", None, "s1"], "source_id": "fallback"}
        assert _collect_source_ids(agent) == ["s1"]
@pytest.mark.unit
class TestSearchInvalidAPIKey:
    """Error paths of ``search``: unknown key vs. infrastructure failure."""

    def test_raises_when_key_unknown(self):
        with _fake_db_readonly(None):
            with pytest.raises(InvalidAPIKey):
                search("does-not-exist", "hello", 5)

    def test_raises_search_failed_on_db_error(self):
        # find_by_key blowing up must surface as SearchFailed, not leak the
        # raw RuntimeError to the caller.
        @contextmanager
        def _yield_conn():
            yield MagicMock()

        agents_repo = MagicMock()
        agents_repo.find_by_key.side_effect = RuntimeError("db down")
        with patch(
            "application.services.search_service.db_readonly", _yield_conn
        ), patch(
            "application.services.search_service.AgentsRepository",
            return_value=agents_repo,
        ):
            with pytest.raises(SearchFailed):
                search("any-key", "hello", 5)
@pytest.mark.unit
class TestSearchEmptyWhenNoSources:
    """``search`` short-circuits to ``[]`` when the agent has no sources."""

    def test_returns_empty_when_agent_has_no_sources(self):
        with _fake_db_readonly({"extra_source_ids": [], "source_id": None}):
            assert search("k", "q", 5) == []
@pytest.mark.unit
class TestSearchResults:
    """Happy-path retrieval behavior with a patched vectorstore."""

    def test_returns_hit_shape(self):
        agent = {"source_id": "src-1", "extra_source_ids": []}
        fake_vs = MagicMock()
        fake_vs.search.return_value = [
            {
                "text": "Test content",
                "metadata": {"title": "Test Title", "source": "/path/to/doc"},
            }
        ]
        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore",
            return_value=fake_vs,
        ):
            results = search("k", "q", 5)
        assert results == [
            {"text": "Test content", "title": "Test Title", "source": "/path/to/doc"}
        ]

    def test_handles_langchain_document_format(self):
        # Hits may be objects exposing page_content/metadata attributes
        # instead of plain dicts.
        agent = {"source_id": "src-1", "extra_source_ids": []}
        lc_doc = MagicMock()
        lc_doc.page_content = "Langchain content"
        lc_doc.metadata = {"title": "LC Title", "source": "/lc/path"}
        fake_vs = MagicMock()
        fake_vs.search.return_value = [lc_doc]
        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore",
            return_value=fake_vs,
        ):
            results = search("k", "q", 5)
        assert len(results) == 1
        assert results[0]["text"] == "Langchain content"
        assert results[0]["title"] == "LC Title"

    def test_respects_chunks_cap(self):
        agent = {"source_id": "src-1", "extra_source_ids": []}
        docs = [
            {"text": f"Content {i}", "metadata": {"title": f"T{i}"}}
            for i in range(10)
        ]
        fake_vs = MagicMock()
        fake_vs.search.return_value = docs
        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore",
            return_value=fake_vs,
        ):
            results = search("k", "q", 3)
        assert len(results) == 3

    def test_deduplicates_results_by_content_prefix(self):
        agent = {"source_id": "src-1", "extra_source_ids": []}
        dup_text = "Duplicate content " * 20
        docs = [
            {"text": dup_text, "metadata": {"title": "T1"}},
            {"text": dup_text, "metadata": {"title": "T2"}},
            {"text": "Unique content", "metadata": {"title": "T3"}},
        ]
        fake_vs = MagicMock()
        fake_vs.search.return_value = docs
        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore",
            return_value=fake_vs,
        ):
            results = search("k", "q", 5)
        assert len(results) == 2

    def test_skips_broken_source_and_returns_from_healthy_ones(self):
        # Two sources — the first raises, the second returns a doc. The
        # caller should still get the healthy source's result.
        agent = {"extra_source_ids": ["broken", "ok"], "source_id": None}
        healthy_vs = MagicMock()
        healthy_vs.search.return_value = [
            {"text": "ok content", "metadata": {"title": "Ok"}}
        ]

        def create_vs(store, source_id, key):
            if source_id == "broken":
                raise RuntimeError("vector index missing")
            return healthy_vs

        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore",
            side_effect=create_vs,
        ):
            results = search("k", "q", 5)
        assert len(results) == 1
        assert results[0]["text"] == "ok content"

    def test_uses_filename_when_title_missing(self):
        agent = {"source_id": "src-1", "extra_source_ids": []}
        fake_vs = MagicMock()
        fake_vs.search.return_value = [
            {"text": "body", "metadata": {"filename": "document.pdf"}}
        ]
        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore",
            return_value=fake_vs,
        ):
            results = search("k", "q", 5)
        assert results[0]["title"] == "document.pdf"

    def test_uses_content_snippet_as_title_last_resort(self):
        # No title, post_title or filename — service falls back to the
        # first 50 characters of the text plus "...".
        agent = {"source_id": "src-1", "extra_source_ids": []}
        fake_vs = MagicMock()
        fake_vs.search.return_value = [
            {"text": "Content without any title metadata at all", "metadata": {}}
        ]
        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore",
            return_value=fake_vs,
        ):
            results = search("k", "q", 5)
        assert results[0]["title"].endswith("...")
        assert "Content without any title" in results[0]["title"]

    def test_skips_empty_source_ids(self):
        # ``source_id=" "`` only — after strip() this leaves no real source.
        agent = {"extra_source_ids": [" ", ""], "source_id": None}
        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore"
        ) as mock_create:
            results = search("k", "q", 5)
        mock_create.assert_not_called()
        assert results == []