Files
DocsGPT/tests/test_retriever.py
2026-03-29 09:46:07 +01:00

354 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from unittest.mock import MagicMock, Mock, patch
import pytest
from application.retriever.base import BaseRetriever
from application.retriever.retriever_creator import RetrieverCreator
# ── BaseRetriever ──────────────────────────────────────────────────────────────
@pytest.mark.unit
class TestBaseRetriever:
    """Checks how BaseRetriever behaves when instantiated directly or subclassed."""

    def test_cannot_instantiate_directly(self):
        """Calling BaseRetriever() itself is expected to raise TypeError."""
        with pytest.raises(TypeError):
            BaseRetriever()

    def test_subclass_must_implement_search(self):
        """A subclass with an empty body is expected to raise TypeError on construction."""

        class Incomplete(BaseRetriever):
            pass

        with pytest.raises(TypeError):
            Incomplete()

    def test_concrete_subclass_works(self):
        """A subclass that defines search() can be constructed and called."""

        class Concrete(BaseRetriever):
            def search(self, *args, **kwargs):
                return "ok"

        retriever = Concrete()
        outcome = retriever.search()
        assert outcome == "ok"
# ── RetrieverCreator ───────────────────────────────────────────────────────────
@pytest.mark.unit
class TestRetrieverCreator:
    """Tests for RetrieverCreator.create_retriever lookup and dispatch.

    Tests that substitute an entry in ``RetrieverCreator.retrievers`` do so
    inside ``patch.dict``. On exit the registry is put back exactly as it
    was, which also removes a key that did not exist before the test. A
    manual ``copy()`` / ``update()`` restore cannot do that, so a substituted
    Mock could otherwise leak into later tests.
    """

    def test_create_classic(self):
        """Positional and keyword arguments are forwarded to the registered class."""
        mock_cls = Mock(return_value="rag_instance")
        with patch.dict(RetrieverCreator.retrievers, {"classic": mock_cls}):
            result = RetrieverCreator.create_retriever("classic", "arg1", key="val")
            mock_cls.assert_called_once_with("arg1", key="val")
            assert result == "rag_instance"

    def test_create_default(self):
        """The "default" entry is called with no arguments when none are given."""
        mock_cls = Mock(return_value="rag_instance")
        with patch.dict(RetrieverCreator.retrievers, {"default": mock_cls}):
            result = RetrieverCreator.create_retriever("default")
            mock_cls.assert_called_once_with()
            assert result == "rag_instance"

    def test_create_none_type_uses_default(self):
        """Passing None as the type resolves to the "default" entry."""
        mock_cls = Mock(return_value="rag_instance")
        with patch.dict(RetrieverCreator.retrievers, {"default": mock_cls}):
            result = RetrieverCreator.create_retriever(None)
            mock_cls.assert_called_once()
            assert result == "rag_instance"

    def test_case_insensitive(self):
        """An upper-case type name reaches the lower-case "classic" entry."""
        mock_cls = Mock(return_value="rag_instance")
        with patch.dict(RetrieverCreator.retrievers, {"classic": mock_cls}):
            RetrieverCreator.create_retriever("CLASSIC")
            mock_cls.assert_called_once()

    def test_invalid_type_raises(self):
        """An unregistered type name raises ValueError."""
        # The misspelling in the pattern is deliberate: it has to match the
        # message text raised by the code under test.
        with pytest.raises(ValueError, match="No retievers class found"):
            RetrieverCreator.create_retriever("nonexistent")
# ── ClassicRAG ─────────────────────────────────────────────────────────────────
@pytest.fixture
def _patch_llm_creator(mock_llm, monkeypatch):
    """Swap ``LLMCreator.create_llm`` in classic_rag for a stub returning ``mock_llm``.

    The fixture value is the same ``mock_llm`` object, so a test can
    configure it before building a ClassicRAG.
    """
    create_llm_stub = Mock(return_value=mock_llm)
    monkeypatch.setattr(
        "application.retriever.classic_rag.LLMCreator.create_llm",
        create_llm_stub,
    )
    return mock_llm
def _make_rag(source=None, _patch_llm_creator=None, **overrides):
    """Construct a ClassicRAG for tests, supplying default constructor arguments.

    Args:
        source: Value passed as ClassicRAG's ``source``. Any falsy value is
            replaced by ``{"question": "hello"}``.
        _patch_llm_creator: Accepted but not used by this helper.
        **overrides: Keyword arguments that replace or extend the defaults.

    Returns:
        The constructed ClassicRAG instance.
    """
    from application.retriever.classic_rag import ClassicRAG

    if not source:
        source = {"question": "hello"}

    base_kwargs = {
        "source": source,
        "chat_history": None,
        "prompt": "",
        "chunks": 2,
        "doc_token_limit": 50000,
        "model_id": "test-model",
        "user_api_key": None,
        "agent_id": None,
        "llm_name": "openai",
        "api_key": "fake",
        "decoded_token": {"sub": "user1"},
    }
    # Caller-supplied values win over the defaults above.
    ctor_kwargs = {**base_kwargs, **overrides}
    return ClassicRAG(**ctor_kwargs)
@pytest.mark.unit
class TestClassicRAGInit:
    """Checks the attributes ClassicRAG exposes right after construction."""

    def test_basic_init(self, _patch_llm_creator):
        """Default helper arguments yield the expected question, chunks and stores."""
        instance = _make_rag()
        assert instance.original_question == "hello"
        assert instance.chunks == 2
        assert instance.vectorstores == []

    def test_active_docs_as_list(self, _patch_llm_creator):
        """A list of active_docs is kept as the vectorstore list."""
        src = {"question": "q", "active_docs": ["a", "b"]}
        instance = _make_rag(source=src)
        assert instance.vectorstores == ["a", "b"]

    def test_active_docs_as_string(self, _patch_llm_creator):
        """A single string active_docs becomes a one-element list."""
        src = {"question": "q", "active_docs": "single"}
        instance = _make_rag(source=src)
        assert instance.vectorstores == ["single"]

    def test_active_docs_none(self, _patch_llm_creator):
        """active_docs=None results in no vectorstores."""
        src = {"question": "q", "active_docs": None}
        instance = _make_rag(source=src)
        assert instance.vectorstores == []

    def test_chunks_string_converted(self, _patch_llm_creator):
        """A numeric string for chunks is stored as an int."""
        instance = _make_rag(chunks="5")
        assert instance.chunks == 5

    def test_chunks_invalid_string_defaults(self, _patch_llm_creator):
        """A non-numeric chunks string ends up as 2."""
        instance = _make_rag(chunks="abc")
        assert instance.chunks == 2

    def test_decoded_token_none(self, _patch_llm_creator):
        """decoded_token=None is stored unchanged."""
        instance = _make_rag(decoded_token=None)
        assert instance.decoded_token is None
@pytest.mark.unit
class TestClassicRAGValidateVectorstore:
    """Checks which vectorstore ids survive ClassicRAG construction."""

    def test_removes_empty_ids(self, _patch_llm_creator):
        """Empty and whitespace-only ids are dropped; the others keep their order."""
        requested_ids = ["ok", "", " ", "good"]
        instance = _make_rag(source={"question": "q", "active_docs": requested_ids})
        assert instance.vectorstores == ["ok", "good"]

    def test_empty_vectorstores_no_error(self, _patch_llm_creator):
        """A source without active_docs builds fine and has no vectorstores."""
        instance = _make_rag(source={"question": "q"})
        assert instance.vectorstores == []
@pytest.mark.unit
class TestClassicRAGRephraseQuery:
    """Checks the value of ``rag.question`` after construction under various inputs."""

    def test_no_history_returns_original(self, _patch_llm_creator):
        """With an empty chat history the question is left as given."""
        src = {"question": "original", "active_docs": ["vs1"]}
        instance = _make_rag(source=src, chat_history=[])
        assert instance.question == "original"

    def test_no_vectorstores_returns_original(self, _patch_llm_creator):
        """Without active_docs the question is left as given, even with history."""
        history = [{"prompt": "hi", "response": "hello"}]
        instance = _make_rag(source={"question": "original"}, chat_history=history)
        assert instance.question == "original"

    def test_chunks_zero_returns_original(self, _patch_llm_creator):
        """With chunks=0 the question is left as given."""
        src = {"question": "original", "active_docs": ["vs1"]}
        history = [{"prompt": "hi", "response": "hello"}]
        instance = _make_rag(source=src, chat_history=history, chunks=0)
        assert instance.question == "original"

    def test_rephrase_called_with_history(self, _patch_llm_creator, mock_llm):
        """With history and a vectorstore, the LLM output becomes the question."""
        mock_llm.gen = Mock(return_value="rephrased question")
        src = {"question": "original", "active_docs": ["vs1"]}
        history = [{"prompt": "hi", "response": "hello"}]
        instance = _make_rag(source=src, chat_history=history)
        assert instance.question == "rephrased question"
        mock_llm.gen.assert_called_once()

    def test_rephrase_llm_returns_empty_falls_back(self, _patch_llm_creator, mock_llm):
        """An empty string from the LLM leaves the original question in place."""
        mock_llm.gen = Mock(return_value="")
        src = {"question": "original", "active_docs": ["vs1"]}
        history = [{"prompt": "hi", "response": "hello"}]
        instance = _make_rag(source=src, chat_history=history)
        assert instance.question == "original"

    def test_rephrase_llm_exception_falls_back(self, _patch_llm_creator, mock_llm):
        """An exception raised by the LLM leaves the original question in place."""
        mock_llm.gen = Mock(side_effect=RuntimeError("boom"))
        src = {"question": "original", "active_docs": ["vs1"]}
        history = [{"prompt": "hi", "response": "hello"}]
        instance = _make_rag(source=src, chat_history=history)
        assert instance.question == "original"
@pytest.mark.unit
class TestClassicRAGGetData:
    """Tests for ClassicRAG._get_data with VectorCreator and token counting patched."""

    @staticmethod
    def _fake_doc(page_content, metadata):
        """Return a MagicMock carrying the given page_content and metadata."""
        doc = MagicMock()
        doc.page_content = page_content
        doc.metadata = metadata
        return doc

    @staticmethod
    def _wire_vectorstore(mock_vc, search_results):
        """Make mock_vc.create_vectorstore return a store whose search() yields search_results."""
        docsearch = MagicMock()
        docsearch.search.return_value = search_results
        mock_vc.create_vectorstore.return_value = docsearch
        return docsearch

    def test_chunks_zero_returns_empty(self, _patch_llm_creator):
        """chunks=0 yields an empty result."""
        instance = _make_rag(chunks=0)
        assert instance._get_data() == []

    def test_no_vectorstores_returns_empty(self, _patch_llm_creator):
        """A source without active_docs yields an empty result."""
        instance = _make_rag(source={"question": "q"})
        assert instance._get_data() == []

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_returns_docs_with_metadata(self, mock_tokens, mock_vc, _patch_llm_creator):
        """Text, title, filename and source are taken from the returned document."""
        metadata = {
            "title": "path/to/Title",
            "filename": "/docs/file.txt",
            "source": "http://example.com",
        }
        self._wire_vectorstore(mock_vc, [self._fake_doc("content here", metadata)])
        instance = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        found = instance._get_data()
        assert len(found) == 1
        first = found[0]
        assert first["text"] == "content here"
        # Only the last path component of title and filename is expected.
        assert first["title"] == "Title"
        assert first["filename"] == "file.txt"
        assert first["source"] == "http://example.com"

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_dict_style_docs(self, mock_tokens, mock_vc, _patch_llm_creator):
        """Plain dict search results are accepted as documents."""
        dict_result = {"text": "dict content", "metadata": {"title": "Dict Title"}}
        self._wire_vectorstore(mock_vc, [dict_result])
        instance = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        found = instance._get_data()
        assert len(found) == 1
        assert found[0]["text"] == "dict content"

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch(
        "application.retriever.classic_rag.num_tokens_from_string",
        return_value=100000,
    )
    def test_token_budget_respected(self, mock_tokens, mock_vc, _patch_llm_creator):
        """Documents larger than the token budget are left out."""
        oversized = self._fake_doc("big content", {"title": "t"})
        self._wire_vectorstore(mock_vc, [oversized, oversized, oversized])
        instance = _make_rag(
            source={"question": "q", "active_docs": ["vs1"]},
            doc_token_limit=100,
        )
        found = instance._get_data()
        # Every doc is reported as 100000 tokens, far above doc_token_limit=100
        # (the original note cites an effective budget of 90), so none is kept.
        assert len(found) == 0

    @patch("application.retriever.classic_rag.VectorCreator")
    def test_vectorstore_error_continues(self, mock_vc, _patch_llm_creator):
        """A failure while creating the vectorstore yields [] rather than raising."""
        mock_vc.create_vectorstore.side_effect = RuntimeError("connection failed")
        instance = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        found = instance._get_data()
        assert found == []

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_multiple_vectorstores(self, mock_tokens, mock_vc, _patch_llm_creator):
        """Two active vectorstores each contribute their document."""
        shared_doc = self._fake_doc("content", {"title": "t", "source": "s"})
        self._wire_vectorstore(mock_vc, [shared_doc])
        instance = _make_rag(source={"question": "q", "active_docs": ["vs1", "vs2"]})
        found = instance._get_data()
        assert len(found) == 2

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_doc_missing_filename_uses_title(
        self, mock_tokens, mock_vc, _patch_llm_creator
    ):
        """When metadata has no filename, the title is reported as the filename."""
        self._wire_vectorstore(mock_vc, [self._fake_doc("content", {"title": "MyTitle"})])
        instance = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        found = instance._get_data()
        assert found[0]["filename"] == "MyTitle"

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_non_string_title_converted(self, mock_tokens, mock_vc, _patch_llm_creator):
        """A non-string title is returned as its string form."""
        self._wire_vectorstore(mock_vc, [self._fake_doc("content", {"title": 42})])
        instance = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        found = instance._get_data()
        assert found[0]["title"] == "42"
@pytest.mark.unit
class TestClassicRAGSearch:
    """Tests for the public ClassicRAG.search entry point."""

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_search_with_query_override(
        self, mock_tokens, mock_vc, _patch_llm_creator, mock_llm
    ):
        """A query passed to search() replaces original_question and is searched."""
        hit = MagicMock()
        hit.page_content = "result"
        hit.metadata = {"title": "t"}

        store = MagicMock()
        store.search.return_value = [hit]
        mock_vc.create_vectorstore.return_value = store

        mock_llm.gen = Mock(return_value="")
        instance = _make_rag(source={"question": "original", "active_docs": ["vs1"]})
        found = instance.search(query="override query")
        assert instance.original_question == "override query"
        assert len(found) == 1

    def test_search_without_query_uses_default(self, _patch_llm_creator):
        """search() with no arguments and no active_docs returns an empty list."""
        instance = _make_rag(source={"question": "q"})
        found = instance.search()
        assert found == []