# File: application/api/answer/routes/search.py
#
# NOTE: the accompanying change to application/api/answer/__init__.py imports
# SearchResource from this module and registers it inside
# init_answer_routes() with api.add_resource(SearchResource, "/api/search").
import logging
from typing import Any, Dict, List, Optional

from flask import make_response, request
from flask_restx import fields, Resource

from bson.dbref import DBRef

from application.api.answer.routes.base import answer_ns
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator

logger = logging.getLogger(__name__)

# Number of results returned when the request does not specify "chunks".
_DEFAULT_CHUNKS = 5
# Leading characters of a chunk that serve as its de-duplication key.
_DEDUP_PREFIX_LENGTH = 200
# Leading characters of a chunk used as a last-resort title.
_TITLE_SNIPPET_LENGTH = 50


@answer_ns.route("/api/search")
class SearchResource(Resource):
    """Fast search endpoint for retrieving relevant documents.

    A POST carries ``question``, ``api_key`` and an optional ``chunks`` count.
    The API key is looked up in the ``agents`` collection, the agent's sources
    are resolved to vector-store IDs, and each store is searched in turn.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        mongo = MongoDB.get_client()
        self.db = mongo[settings.MONGO_DB_NAME]
        self.agents_collection = self.db["agents"]

    search_model = answer_ns.model(
        "SearchModel",
        {
            "question": fields.String(
                required=True, description="Search query"
            ),
            "api_key": fields.String(
                required=True, description="API key for authentication"
            ),
            "chunks": fields.Integer(
                required=False,
                default=_DEFAULT_CHUNKS,
                description="Number of results to return",
            ),
        },
    )

    def _get_sources_from_api_key(
        self, api_key: str, agent_data: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """Get source IDs connected to the API key/agent.

        Args:
            api_key: Key stored in the ``key`` field of an agent document.
            agent_data: Agent document that was already fetched for this key.
                When omitted the agent is looked up by ``api_key``.

        Returns:
            Source IDs as strings; empty when the agent is unknown or has no
            usable source. The ``"default"`` placeholder is never returned.
        """
        if agent_data is None:
            # A non-string key (e.g. a dict) would be interpreted by MongoDB
            # as a query operator, so it must never reach find_one().
            if not isinstance(api_key, str):
                return []
            agent_data = self.agents_collection.find_one({"key": api_key})
        if not agent_data:
            return []

        source_ids: List[str] = []

        # Multiple sources: only DBRef entries are resolved. Anything else,
        # including the "default" placeholder, is skipped.
        sources = agent_data.get("sources", [])
        if isinstance(sources, list):
            for source_ref in sources:
                if not isinstance(source_ref, DBRef):
                    continue
                source_doc = self.db.dereference(source_ref)
                if source_doc:
                    source_ids.append(str(source_doc["_id"]))

        # Legacy single source: consulted only when "sources" yielded nothing.
        if not source_ids:
            source = agent_data.get("source")
            if isinstance(source, DBRef):
                source_doc = self.db.dereference(source)
                if source_doc:
                    source_ids.append(str(source_doc["_id"]))
            elif source and source != "default":
                # str() keeps the List[str] contract for non-string IDs.
                source_ids.append(str(source))

        return source_ids

    @staticmethod
    def _unpack_doc(doc: Any):
        """Return ``(text, metadata)`` for a Document-like object or a dict."""
        if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
            page_content = doc.page_content
            metadata = doc.metadata
        else:
            page_content = doc.get("text", doc.get("page_content", ""))
            metadata = doc.get("metadata", {})

        # Stores may hand back explicit None values; normalise them so that
        # slicing and .get() below cannot raise.
        if page_content is None:
            page_content = ""
        elif not isinstance(page_content, str):
            page_content = str(page_content)
        return page_content, (metadata or {})

    @staticmethod
    def _derive_title(metadata: Dict[str, Any], page_content: str) -> str:
        """Pick a display title: title/post_title, then filename, then snippet."""
        title = metadata.get("title") or metadata.get("post_title") or ""
        if not isinstance(title, str):
            title = str(title)

        # Keep only the last path component; trailing slashes are dropped
        # first so "docs/guide/" yields "guide" rather than "".
        title = title.rstrip("/").split("/")[-1]
        if title:
            return title

        filename = metadata.get("filename")
        if filename:
            return str(filename)
        return page_content[:_TITLE_SNIPPET_LENGTH] + "..."

    def _search_vectorstores(
        self, query: str, source_ids: List[str], chunks: int
    ) -> List[Dict[str, Any]]:
        """Search across vectorstores and return results.

        Args:
            query: Text passed to each vector store's ``search``.
            source_ids: Vector-store IDs to query, in order.
            chunks: Maximum number of results to return.

        Returns:
            At most ``chunks`` dicts with ``text``, ``title`` and ``source``
            keys. A store that raises is logged and skipped.
        """
        if not source_ids or chunks < 1:
            return []

        results: List[Dict[str, Any]] = []
        chunks_per_source = max(1, chunks // len(source_ids))
        # Chunks are considered duplicates when their leading characters match.
        seen_prefixes = set()

        for source_id in source_ids:
            if len(results) >= chunks:
                break
            if not source_id or not source_id.strip():
                continue

            try:
                docsearch = VectorCreator.create_vectorstore(
                    settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
                )
                # Over-fetch so de-duplication can still fill the quota.
                docs = docsearch.search(query, k=chunks_per_source * 2)

                for doc in docs:
                    if len(results) >= chunks:
                        break

                    page_content, metadata = self._unpack_doc(doc)

                    prefix = page_content[:_DEDUP_PREFIX_LENGTH]
                    if prefix in seen_prefixes:
                        continue
                    seen_prefixes.add(prefix)

                    results.append(
                        {
                            "text": page_content,
                            "title": self._derive_title(metadata, page_content),
                            "source": metadata.get("source", source_id),
                        }
                    )
            except Exception as e:
                # Best effort: one failing store must not fail the request.
                logger.error(
                    "Error searching vectorstore %s: %s",
                    source_id,
                    e,
                    exc_info=True,
                )
                continue

        return results[:chunks]

    @answer_ns.expect(search_model)
    @answer_ns.doc(description="Search for relevant documents based on query")
    def post(self):
        """Validate the request, authenticate the API key and run the search.

        Returns:
            200 with a JSON list of results (possibly empty), 400 for a
            missing/invalid ``question``, ``api_key`` or ``chunks``, 401 for
            an unknown API key, 500 when the lookup or search raises.
        """
        # silent=True: a missing or malformed JSON body becomes None instead
        # of aborting, so it is reported through the 400 checks below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        question = data.get("question")
        api_key = data.get("api_key")
        chunks = data.get("chunks")
        if chunks is None:
            chunks = _DEFAULT_CHUNKS

        if not question or not isinstance(question, str):
            return make_response({"error": "question is required"}, 400)

        # The string check also blocks operator injection such as
        # {"api_key": {"$ne": null}}, which would match any agent.
        if not api_key or not isinstance(api_key, str):
            return make_response({"error": "api_key is required"}, 400)

        # bool is a subclass of int, so it is excluded explicitly.
        if isinstance(chunks, bool) or not isinstance(chunks, int) or chunks < 1:
            return make_response(
                {"error": "chunks must be a positive integer"}, 400
            )

        try:
            # Validate API key
            agent = self.agents_collection.find_one({"key": api_key})
            if not agent:
                return make_response({"error": "Invalid API key"}, 401)

            # Reuse the fetched agent document to avoid a second lookup.
            source_ids = self._get_sources_from_api_key(
                api_key, agent_data=agent
            )

            if not source_ids:
                return make_response([], 200)

            results = self._search_vectorstores(question, source_ids, chunks)

            return make_response(results, 200)

        except Exception as e:
            logger.error(
                "/api/search - error: %s",
                e,
                extra={"error": str(e)},
                exc_info=True,
            )
            return make_response({"error": "Search failed"}, 500)
# File: tests/api/answer/routes/test_search.py
#
# NOTE: the same change set also bumps these pins in
# application/requirements.txt: langchain-core 1.1.3 -> 1.2.4,
# lazy-object-proxy 1.10.0 -> 1.12.0, prompt-toolkit 3.0.51 -> 3.0.52,
# wcwidth 0.2.13 -> 0.2.14.
from contextlib import contextmanager
from unittest.mock import MagicMock, patch

import pytest
from bson import ObjectId
from bson.dbref import DBRef

# Dotted path of the factory that every vector-store test replaces.
CREATE_VECTORSTORE = (
    "application.api.answer.routes.search.VectorCreator.create_vectorstore"
)
API_KEY = "test_api_key"
QUERY = "test query"


def _make_resource():
    """Import lazily (after fixtures are active) and build a SearchResource."""
    from application.api.answer.routes.search import SearchResource

    return SearchResource()


def _collection(mock_mongo_db, name):
    """Return the named collection of the configured test database."""
    from application.core.settings import settings

    return mock_mongo_db[settings.MONGO_DB_NAME][name]


def _add_source(mock_mongo_db, name="Test Source"):
    """Insert a source document and return its ObjectId."""
    source_id = ObjectId()
    _collection(mock_mongo_db, "sources").insert_one(
        {"_id": source_id, "name": name}
    )
    return source_id


def _add_agent(mock_mongo_db, **source_fields):
    """Insert an agent keyed by API_KEY with the given source/sources fields."""
    _collection(mock_mongo_db, "agents").insert_one(
        {"_id": ObjectId(), "key": API_KEY, **source_fields}
    )


@contextmanager
def _stub_vectorstore(docs):
    """Patch the vector-store factory; yield (factory mock, store mock)."""
    with patch(CREATE_VECTORSTORE) as create_mock:
        store_mock = MagicMock()
        store_mock.search.return_value = docs
        create_mock.return_value = store_mock
        yield create_mock, store_mock


def _search_with_docs(flask_app, docs, chunks=5):
    """Run _search_vectorstores against one stubbed store returning ``docs``."""
    with flask_app.app_context():
        resource = _make_resource()
        with _stub_vectorstore(docs):
            return resource._search_vectorstores(QUERY, ["source_id"], chunks)


def _post(flask_app, payload):
    """POST ``payload`` as JSON to a fresh resource and return the response."""
    with flask_app.app_context():
        with flask_app.test_request_context(json=payload):
            return _make_resource().post()


@pytest.mark.unit
class TestSearchResourceValidation:
    def test_returns_error_when_question_missing(self, mock_mongo_db, flask_app):
        response = _post(flask_app, {"api_key": "test_key"})

        assert response.status_code == 400
        assert "question" in response.json["error"]

    def test_returns_error_when_api_key_missing(self, mock_mongo_db, flask_app):
        response = _post(flask_app, {"question": QUERY})

        assert response.status_code == 400
        assert "api_key" in response.json["error"]

    def test_returns_error_for_invalid_api_key(self, mock_mongo_db, flask_app):
        response = _post(flask_app, {"question": QUERY, "api_key": "invalid_key"})

        assert response.status_code == 401
        assert "Invalid API key" in response.json["error"]


@pytest.mark.unit
class TestGetSourcesFromApiKey:
    def test_returns_empty_list_when_agent_not_found(self, mock_mongo_db, flask_app):
        with flask_app.app_context():
            found = _make_resource()._get_sources_from_api_key("nonexistent_key")

            assert found == []

    def test_returns_source_id_from_dbref(self, mock_mongo_db, flask_app):
        with flask_app.app_context():
            source_id = _add_source(mock_mongo_db)
            _add_agent(
                mock_mongo_db, source=DBRef("sources", source_id), sources=[]
            )

            found = _make_resource()._get_sources_from_api_key(API_KEY)

            assert len(found) == 1
            assert found[0] == str(source_id)

    def test_returns_multiple_sources_from_sources_array(
        self, mock_mongo_db, flask_app
    ):
        with flask_app.app_context():
            first_id = _add_source(mock_mongo_db, "Source 1")
            second_id = _add_source(mock_mongo_db, "Source 2")
            _add_agent(
                mock_mongo_db,
                sources=[DBRef("sources", first_id), DBRef("sources", second_id)],
            )

            found = _make_resource()._get_sources_from_api_key(API_KEY)

            assert len(found) == 2
            assert str(first_id) in found
            assert str(second_id) in found

    def test_skips_default_source_in_sources_array(self, mock_mongo_db, flask_app):
        with flask_app.app_context():
            source_id = _add_source(mock_mongo_db)
            _add_agent(
                mock_mongo_db, sources=["default", DBRef("sources", source_id)]
            )

            found = _make_resource()._get_sources_from_api_key(API_KEY)

            assert len(found) == 1
            assert found[0] == str(source_id)
            assert "default" not in found

    def test_skips_default_source_in_legacy_field(self, mock_mongo_db, flask_app):
        with flask_app.app_context():
            _add_agent(mock_mongo_db, source="default", sources=[])

            found = _make_resource()._get_sources_from_api_key(API_KEY)

            assert found == []

    def test_falls_back_to_legacy_source_when_sources_empty(
        self, mock_mongo_db, flask_app
    ):
        with flask_app.app_context():
            source_id = _add_source(mock_mongo_db)
            _add_agent(
                mock_mongo_db, source=DBRef("sources", source_id), sources=[]
            )

            found = _make_resource()._get_sources_from_api_key(API_KEY)

            assert len(found) == 1
            assert found[0] == str(source_id)

    def test_handles_string_source_id(self, mock_mongo_db, flask_app):
        with flask_app.app_context():
            source_id = "custom_source_id"
            _add_agent(mock_mongo_db, source=source_id, sources=[])

            found = _make_resource()._get_sources_from_api_key(API_KEY)

            assert len(found) == 1
            assert found[0] == source_id


@pytest.mark.unit
class TestSearchVectorstores:
    def test_returns_empty_when_no_source_ids(self, mock_mongo_db, flask_app):
        with flask_app.app_context():
            found = _make_resource()._search_vectorstores(QUERY, [], 5)

            assert found == []

    def test_skips_empty_source_ids(self, mock_mongo_db, flask_app):
        with flask_app.app_context():
            resource = _make_resource()

            with _stub_vectorstore([]) as (create_mock, _store):
                found = resource._search_vectorstores(QUERY, ["", " "], 5)

                create_mock.assert_not_called()
                assert found == []

    def test_returns_search_results(self, mock_mongo_db, flask_app):
        doc = {
            "text": "Test content",
            "page_content": "Test content",
            "metadata": {"title": "Test Title", "source": "/path/to/doc"},
        }

        found = _search_with_docs(flask_app, [doc])

        assert len(found) == 1
        assert found[0]["text"] == "Test content"
        assert found[0]["title"] == "Test Title"
        assert found[0]["source"] == "/path/to/doc"

    def test_handles_langchain_document_format(self, mock_mongo_db, flask_app):
        # Attribute-style document, as opposed to the dict form used elsewhere.
        doc = MagicMock()
        doc.page_content = "Langchain content"
        doc.metadata = {"title": "LC Title", "source": "/lc/path"}

        found = _search_with_docs(flask_app, [doc])

        assert len(found) == 1
        assert found[0]["text"] == "Langchain content"
        assert found[0]["title"] == "LC Title"

    def test_respects_chunks_limit(self, mock_mongo_db, flask_app):
        docs = [
            {"text": f"Content {i}", "metadata": {"title": f"Title {i}"}}
            for i in range(10)
        ]

        found = _search_with_docs(flask_app, docs, chunks=3)

        assert len(found) == 3

    def test_deduplicates_results(self, mock_mongo_db, flask_app):
        repeated = "Duplicate content " * 20
        docs = [
            {"text": repeated, "metadata": {"title": "Title 1"}},
            {"text": repeated, "metadata": {"title": "Title 2"}},
            {"text": "Unique content", "metadata": {"title": "Title 3"}},
        ]

        found = _search_with_docs(flask_app, docs)

        assert len(found) == 2

    def test_handles_vectorstore_error_gracefully(self, mock_mongo_db, flask_app):
        with flask_app.app_context():
            resource = _make_resource()

            with patch(CREATE_VECTORSTORE) as create_mock:
                create_mock.side_effect = Exception("Vectorstore error")

                found = resource._search_vectorstores(QUERY, ["source_id"], 5)

                assert found == []

    def test_uses_filename_as_title_fallback(self, mock_mongo_db, flask_app):
        doc = {
            "text": "Content without title",
            "metadata": {"filename": "document.pdf"},
        }

        found = _search_with_docs(flask_app, [doc])

        assert found[0]["title"] == "document.pdf"

    def test_uses_content_snippet_as_title_last_resort(self, mock_mongo_db, flask_app):
        doc = {
            "text": "Content without any title metadata at all",
            "metadata": {},
        }

        found = _search_with_docs(flask_app, [doc])

        assert "Content without any title" in found[0]["title"]
        assert found[0]["title"].endswith("...")


@pytest.mark.unit
class TestSearchEndpoint:
    def test_returns_empty_array_when_no_sources(self, mock_mongo_db, flask_app):
        with flask_app.app_context():
            _add_agent(mock_mongo_db, source="default", sources=[])

            with flask_app.test_request_context(
                json={"question": QUERY, "api_key": API_KEY}
            ):
                response = _make_resource().post()

                assert response.status_code == 200
                assert response.json == []

    def test_returns_search_results_successfully(self, mock_mongo_db, flask_app):
        with flask_app.app_context():
            source_id = _add_source(mock_mongo_db)
            _add_agent(
                mock_mongo_db, source=DBRef("sources", source_id), sources=[]
            )

            doc = {
                "text": "Search result content",
                "metadata": {"title": "Result Title", "source": "/doc/path"},
            }

            with flask_app.test_request_context(
                json={"question": QUERY, "api_key": API_KEY, "chunks": 5}
            ):
                with _stub_vectorstore([doc]):
                    response = _make_resource().post()

                    assert response.status_code == 200
                    assert len(response.json) == 1
                    assert response.json[0]["text"] == "Search result content"
                    assert response.json[0]["title"] == "Result Title"

    def test_uses_default_chunks_value(self, mock_mongo_db, flask_app):
        with flask_app.app_context():
            source_id = _add_source(mock_mongo_db)
            _add_agent(
                mock_mongo_db, source=DBRef("sources", source_id), sources=[]
            )

            with flask_app.test_request_context(
                json={"question": QUERY, "api_key": API_KEY}
            ):
                with _stub_vectorstore([]) as (_create, store_mock):
                    _make_resource().post()

                    # Default of 5 chunks over one source, doubled by the
                    # resource when it calls search.
                    store_mock.search.assert_called_once()
                    assert store_mock.search.call_args[1]["k"] == 10

    def test_handles_internal_error(self, mock_mongo_db, flask_app):
        with flask_app.app_context():
            source_id = _add_source(mock_mongo_db)
            _add_agent(
                mock_mongo_db, source=DBRef("sources", source_id), sources=[]
            )

            with flask_app.test_request_context(
                json={"question": QUERY, "api_key": API_KEY}
            ):
                resource = _make_resource()

                with patch.object(
                    resource, "_get_sources_from_api_key"
                ) as get_sources_mock:
                    get_sources_mock.side_effect = Exception("Database error")

                    response = resource.post()

                    assert response.status_code == 500
                    assert "Search failed" in response.json["error"]