Files
DocsGPT/tests/vectorstore/test_pgvector.py
2026-04-12 12:35:23 +01:00

297 lines
9.9 KiB
Python

from unittest.mock import MagicMock, Mock, patch
import pytest
def _make_store(
    source_id="test-source",
    embeddings_key="key",
    connection_string="postgresql://user:pass@localhost/db",
):
    """Build a ``PGVectorStore`` whose external dependencies are all fakes.

    While the store is constructed:

    * ``BaseVectorStore._get_embeddings`` is patched to return a ``Mock``
      embeddings object.
    * The pgvector module's ``settings`` object is patched.
    * ``psycopg``, ``pgvector`` and ``pgvector.psycopg`` are replaced in
      ``sys.modules`` by ``MagicMock`` stand-ins.
    * ``_ensure_table_exists`` is stubbed out, so ``__init__`` issues no SQL.

    After construction, ``store._connection`` is replaced by a mock
    connection.

    Returns:
        A ``(store, connection, cursor, embeddings)`` tuple. ``connection.cursor()``
        returns ``cursor``, and ``connection.closed`` is False.
    """
    stub_modules = {
        module_name: MagicMock()
        for module_name in ("psycopg", "pgvector", "pgvector.psycopg")
    }

    # The three patches are entered in this order on purpose: embeddings,
    # then settings, then the sys.modules overlay.
    with patch(
        "application.vectorstore.base.BaseVectorStore._get_embeddings"
    ) as get_embeddings:
        with patch("application.vectorstore.pgvector.settings") as fake_settings:
            with patch.dict("sys.modules", stub_modules):
                fake_embeddings = Mock(
                    embed_query=Mock(return_value=[0.1, 0.2, 0.3]),
                    embed_documents=Mock(return_value=[[0.1, 0.2, 0.3]]),
                    dimension=768,
                )
                get_embeddings.return_value = fake_embeddings
                fake_settings.EMBEDDINGS_NAME = "test_model"
                fake_settings.PGVECTOR_CONNECTION_STRING = connection_string

                from application.vectorstore.pgvector import PGVectorStore

                # Keep __init__ away from the database: table bootstrap is a no-op.
                with patch.object(PGVectorStore, "_ensure_table_exists"):
                    store = PGVectorStore(
                        source_id=source_id,
                        embeddings_key=embeddings_key,
                        connection_string=connection_string,
                    )

                # Swap in a fake, still-open connection whose cursor() is ours.
                fake_cursor = MagicMock()
                fake_conn = MagicMock(closed=False)
                fake_conn.cursor.return_value = fake_cursor
                store._connection = fake_conn
                return store, fake_conn, fake_cursor, fake_embeddings
@pytest.mark.unit
class TestPGVectorStoreInit:
    """Constructor behaviour of ``PGVectorStore``."""

    def test_source_id_cleaned(self):
        # The path-style prefix and trailing slash are stripped from the id.
        store = _make_store(source_id="application/indexes/abc123/")[0]
        assert store._source_id == "abc123"

    def test_missing_connection_string_raises(self):
        stub_modules = {
            module_name: MagicMock()
            for module_name in ("psycopg", "pgvector", "pgvector.psycopg")
        }
        with patch(
            "application.vectorstore.base.BaseVectorStore._get_embeddings"
        ) as get_embeddings:
            with patch("application.vectorstore.pgvector.settings") as fake_settings:
                with patch.dict("sys.modules", stub_modules):
                    get_embeddings.return_value = Mock(dimension=768)
                    fake_settings.EMBEDDINGS_NAME = "test_model"
                    # Both connection-string sources in settings are unset, and
                    # no explicit connection string is passed either.
                    fake_settings.PGVECTOR_CONNECTION_STRING = None
                    fake_settings.POSTGRES_URI = None

                    from application.vectorstore.pgvector import PGVectorStore

                    with pytest.raises(
                        ValueError, match="connection string is required"
                    ):
                        PGVectorStore(
                            source_id="test",
                            embeddings_key="key",
                            connection_string=None,
                        )
@pytest.mark.unit
class TestPGVectorStoreSearch:
    """``PGVectorStore.search``: row mapping and error handling."""

    def test_search_returns_documents(self):
        store, _, mock_cursor, mock_emb = _make_store()
        mock_cursor.fetchall.return_value = [
            ("hello world", {"source": "test.txt"}, 0.1),
            ("foo bar", {"source": "test2.txt"}, 0.2),
        ]

        results = store.search("query", k=2)

        mock_emb.embed_query.assert_called_once_with("query")
        assert len(results) == 2
        # Every row is converted, not just the first, in the order the
        # cursor returned them.
        assert [doc.page_content for doc in results] == ["hello world", "foo bar"]
        assert [doc.metadata for doc in results] == [
            {"source": "test.txt"},
            {"source": "test2.txt"},
        ]

    def test_search_returns_empty_on_error(self):
        store, _, mock_cursor, _ = _make_store()
        mock_cursor.execute.side_effect = Exception("connection lost")

        results = store.search("query")

        # A database failure yields no results rather than propagating.
        assert results == []

    def test_search_handles_null_metadata(self):
        store, _, mock_cursor, _ = _make_store()
        mock_cursor.fetchall.return_value = [("text", None, 0.5)]

        results = store.search("query")

        assert len(results) == 1
        # A NULL metadata column is surfaced as an empty dict.
        assert results[0].metadata == {}
@pytest.mark.unit
class TestPGVectorStoreAddTexts:
    """``PGVectorStore.add_texts``: inserts, returned ids and rollback."""

    def test_add_texts_inserts_and_returns_ids(self):
        store, mock_conn, mock_cursor, mock_emb = _make_store()
        mock_emb.embed_documents.return_value = [[0.1, 0.2], [0.3, 0.4]]
        # fetchone is read once per inserted text to obtain its id.
        mock_cursor.fetchone.side_effect = [(1,), (2,)]

        ids = store.add_texts(["text1", "text2"], [{"a": 1}, {"b": 2}])

        # Ids come back as strings: one execute per text, one commit overall.
        assert ids == ["1", "2"]
        assert mock_cursor.execute.call_count == 2
        mock_conn.commit.assert_called_once()

    def test_add_texts_empty_returns_empty(self):
        store, _, _, _ = _make_store()
        assert store.add_texts([]) == []

    def test_add_texts_default_metadatas(self):
        store, mock_conn, mock_cursor, mock_emb = _make_store()
        mock_emb.embed_documents.return_value = [[0.1, 0.2]]
        mock_cursor.fetchone.return_value = (1,)

        # No metadatas argument: the insert must still go through normally.
        ids = store.add_texts(["text1"])

        assert ids == ["1"]
        assert mock_cursor.execute.call_count == 1
        mock_conn.commit.assert_called_once()

    def test_add_texts_rolls_back_on_error(self):
        store, mock_conn, mock_cursor, mock_emb = _make_store()
        mock_emb.embed_documents.return_value = [[0.1]]
        mock_cursor.execute.side_effect = Exception("insert failed")

        # The original error propagates, and the transaction is rolled back.
        with pytest.raises(Exception, match="insert failed"):
            store.add_texts(["text1"])
        mock_conn.rollback.assert_called_once()
@pytest.mark.unit
class TestPGVectorStoreDeleteIndex:
    """``PGVectorStore.delete_index``: scoped delete and rollback."""

    def test_delete_index_deletes_by_source_id(self):
        store, mock_conn, mock_cursor, _ = _make_store(source_id="src123")

        store.delete_index()

        # A single DELETE, parameterised by the store's source id.
        mock_cursor.execute.assert_called_once()
        sql = mock_cursor.execute.call_args[0][0]
        assert "DELETE FROM" in sql
        assert mock_cursor.execute.call_args[0][1] == ("src123",)
        mock_conn.commit.assert_called_once()

    def test_delete_index_rolls_back_on_error(self):
        store, mock_conn, mock_cursor, _ = _make_store()
        mock_cursor.execute.side_effect = Exception("fail")

        # Pin the message so an unrelated exception cannot satisfy the test
        # (matches the add_texts rollback test's style).
        with pytest.raises(Exception, match="fail"):
            store.delete_index()
        mock_conn.rollback.assert_called_once()
@pytest.mark.unit
class TestPGVectorStoreSaveLocal:
    """``PGVectorStore.save_local`` is expected to do nothing."""

    def test_save_local_is_noop(self):
        store = _make_store()[0]
        outcome = store.save_local()
        assert outcome is None
@pytest.mark.unit
class TestPGVectorStoreGetChunks:
    """``PGVectorStore.get_chunks``: row conversion and error handling."""

    def test_get_chunks(self):
        store, _, cursor, _ = _make_store()
        cursor.fetchall.return_value = [
            (1, "text1", {"key": "val"}),
            (2, "text2", None),
        ]
        expected = [
            {"doc_id": "1", "text": "text1", "metadata": {"key": "val"}},
            # A NULL metadata column is expected to become an empty dict.
            {"doc_id": "2", "text": "text2", "metadata": {}},
        ]

        chunks = store.get_chunks()

        assert list(chunks) == expected

    def test_get_chunks_returns_empty_on_error(self):
        store, _, cursor, _ = _make_store()
        cursor.execute.side_effect = Exception("fail")
        # A failing query yields an empty list instead of raising.
        assert store.get_chunks() == []
@pytest.mark.unit
class TestPGVectorStoreAddChunk:
    """``PGVectorStore.add_chunk``: insert, returned id and validation."""

    def test_add_chunk(self):
        store, mock_conn, mock_cursor, mock_emb = _make_store(source_id="src1")
        mock_emb.embed_documents.return_value = [[0.1, 0.2]]
        mock_cursor.fetchone.return_value = (42,)

        chunk_id = store.add_chunk("hello", metadata={"key": "val"})

        # The new row id is returned as a string, and the insert is committed.
        assert chunk_id == "42"
        mock_conn.commit.assert_called_once()

    def test_add_chunk_raises_on_empty_embedding(self):
        store, _, _, mock_emb = _make_store()
        mock_emb.embed_documents.return_value = []

        with pytest.raises(ValueError, match="Could not generate embedding"):
            store.add_chunk("text")

    def test_add_chunk_includes_source_id_in_metadata(self):
        store, _, mock_cursor, mock_emb = _make_store(source_id="src1")
        mock_emb.embed_documents.return_value = [[0.1, 0.2]]
        mock_cursor.fetchone.return_value = (1,)

        store.add_chunk("hello", metadata={"key": "val"})

        # The last execute() call is expected to be the INSERT. Check that
        # first, so the positional-parameter assertion below is meaningful.
        insert_call = mock_cursor.execute.call_args
        sql, params = insert_call[0][0], insert_call[0][1]
        assert "INSERT" in sql.upper()
        # source_id is the 4th parameter of the INSERT.
        assert params[3] == "src1"
@pytest.mark.unit
class TestPGVectorStoreDeleteChunk:
    """``PGVectorStore.delete_chunk`` reports success through its return value."""

    def test_delete_chunk_success(self):
        store, conn, cursor, _ = _make_store()
        cursor.rowcount = 1  # one row matched the id
        assert store.delete_chunk("42") is True
        conn.commit.assert_called_once()

    def test_delete_chunk_not_found(self):
        store, _, cursor, _ = _make_store()
        cursor.rowcount = 0  # no row matched the id
        assert store.delete_chunk("999") is False

    def test_delete_chunk_returns_false_on_error(self):
        store, _, cursor, _ = _make_store()
        cursor.execute.side_effect = Exception("fail")
        # A database failure is reported as False, not raised.
        assert store.delete_chunk("42") is False
@pytest.mark.unit
class TestPGVectorStoreConnection:
    """Connection lifecycle: reconnect, reuse, table bootstrap and finalizer."""

    def test_get_connection_creates_new_when_closed(self):
        store, stale_conn, _, _ = _make_store()
        stale_conn.closed = True
        fresh_conn = MagicMock()
        fake_psycopg = MagicMock()
        fake_psycopg.connect.return_value = fresh_conn
        store._psycopg = fake_psycopg

        connection = store._get_connection()

        fake_psycopg.connect.assert_called_once()
        assert connection is fresh_conn

    def test_get_connection_reuses_open(self):
        store, open_conn, _, _ = _make_store()
        open_conn.closed = False
        assert store._get_connection() is open_conn

    def test_ensure_table_exists(self):
        store, conn, cursor, _ = _make_store()
        # _make_store stubbed this out during __init__; here it runs for real
        # against the mock cursor.
        store._ensure_table_exists()
        # Expected: CREATE EXTENSION, CREATE TABLE and CREATE INDEX statements.
        assert cursor.execute.call_count >= 3
        conn.commit.assert_called()

    def test_del_closes_connection(self):
        store, conn, _, _ = _make_store()
        conn.closed = False
        # Call the finalizer explicitly instead of relying on GC timing.
        store.__del__()
        conn.close.assert_called_once()