Merge pull request #1906 from arc53/feat/pg-vector

feat: implement PGVectorStore for PostgreSQL vector storage
Authored by Alex on 2025-08-06 10:42:34 +01:00, committed by GitHub
5 changed files with 328 additions and 13 deletions

View File

@@ -62,12 +62,15 @@ user_logs_collection = db["user_logs"]
 user_tools_collection = db["user_tools"]
 attachments_collection = db["attachments"]
 
-agents_collection.create_index(
-    [("shared", 1)],
-    name="shared_index",
-    background=True,
-)
-users_collection.create_index("user_id", unique=True)
+try:
+    agents_collection.create_index(
+        [("shared_publicly", 1)],
+        name="shared_index",
+        background=True,
+    )
+    users_collection.create_index("user_id", unique=True)
+except Exception as e:
+    print("Error creating indexes:", e)
 
 user = Blueprint("user", __name__)
 user_ns = Namespace("user", description="User related operations", path="/")
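The index key changes from shared to shared_publicly, and index creation is now wrapped in try/except so a failed build no longer aborts module import — for example when an index named shared_index already exists with different keys, or when users_collection already holds duplicate user_id values that violate unique=True. A minimal sketch of that failure mode, assuming a local MongoDB instance; the database and collection names are illustrative, not taken from this PR:

from pymongo import MongoClient
from pymongo.errors import DuplicateKeyError, OperationFailure

client = MongoClient("mongodb://localhost:27017")  # assumed local instance
users = client["docsgpt"]["users"]

users.insert_many([{"user_id": "u1"}, {"user_id": "u1"}])  # duplicates on purpose
try:
    users.create_index("user_id", unique=True)  # cannot build a unique index over duplicates
except (DuplicateKeyError, OperationFailure) as e:
    print("Error creating indexes:", e)  # startup continues instead of crashing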

View File

@@ -30,6 +30,7 @@ class Settings(BaseSettings):
     }
     UPLOAD_FOLDER: str = "inputs"
     PARSE_PDF_AS_IMAGE: bool = False
+    PARSE_IMAGE_REMOTE: bool = False
     VECTOR_STORE: str = (
         "faiss"  # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
     )
@@ -89,6 +90,8 @@ class Settings(BaseSettings):
     QDRANT_PATH: Optional[str] = None
     QDRANT_DISTANCE_FUNC: str = "Cosine"
 
+    # PGVector vectorstore config
+    PGVECTOR_CONNECTION_STRING: Optional[str] = None
     # Milvus vectorstore config
     MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
     MILVUS_URI: Optional[str] = "./milvus_local.db"  # milvus lite version as default
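Together with the pgvector registration in VectorCreator (below), this setting lets the backend be switched entirely through configuration. A minimal sketch, assuming the pydantic Settings class reads environment variables as usual; the connection details are placeholders:

import os

# Placeholder DSN; point this at a PostgreSQL instance with the pgvector extension.
os.environ["VECTOR_STORE"] = "pgvector"
os.environ["PGVECTOR_CONNECTION_STRING"] = "postgresql://user:password@localhost:5432/docsgpt"

from application.core.settings import settings  # import after the env is set

print(settings.VECTOR_STORE)                # -> "pgvector"
print(settings.PGVECTOR_CONNECTION_STRING)  # -> the DSN above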

View File

@@ -8,6 +8,7 @@ import requests
 from typing import Dict, Union
 
 from application.parser.file.base_parser import BaseParser
+from application.core.settings import settings
 
 
 class ImageParser(BaseParser):
@@ -18,10 +19,13 @@ class ImageParser(BaseParser):
         return {}
 
     def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, list[str]]:
-        doc2md_service = "https://llm.arc53.com/doc2md"
-        # alternatively you can use local vision capable LLM
-        with open(file, "rb") as file_loaded:
-            files = {'file': file_loaded}
-            response = requests.post(doc2md_service, files=files)
-            data = response.json()["markdown"]
+        if settings.PARSE_IMAGE_REMOTE:
+            doc2md_service = "https://llm.arc53.com/doc2md"
+            # alternatively you can use local vision capable LLM
+            with open(file, "rb") as file_loaded:
+                files = {'file': file_loaded}
+                response = requests.post(doc2md_service, files=files)
+                data = response.json()["markdown"]
+        else:
+            data = ""
         return data
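Remote image parsing is now opt-in: with the default PARSE_IMAGE_REMOTE = False, parse_file returns an empty string instead of uploading the image to llm.arc53.com. A small sketch, assuming the module lives at application.parser.file.image_parser and that ImageParser can be constructed without arguments:

from pathlib import Path

from application.parser.file.image_parser import ImageParser  # assumed module path

parser = ImageParser()
# Default settings: the image never leaves the machine and no text is extracted.
assert parser.parse_file(Path("diagram.png")) == ""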

View File

@@ -0,0 +1,303 @@
import logging

from typing import List, Optional, Any, Dict

from application.core.settings import settings
from application.vectorstore.base import BaseVectorStore
from application.vectorstore.document_class import Document


class PGVectorStore(BaseVectorStore):
    def __init__(
        self,
        source_id: str = "",
        embeddings_key: str = "embeddings",
        table_name: str = "documents",
        vector_column: str = "embedding",
        text_column: str = "text",
        metadata_column: str = "metadata",
        connection_string: Optional[str] = None,
    ):
        super().__init__()
        # Store the source_id for use in add_chunk
        self._source_id = str(source_id).replace("application/indexes/", "").rstrip("/")
        self._embeddings_key = embeddings_key
        self._table_name = table_name
        self._vector_column = vector_column
        self._text_column = text_column
        self._metadata_column = metadata_column
        self._embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)

        # Use provided connection string or fall back to settings
        self._connection_string = connection_string or getattr(settings, 'PGVECTOR_CONNECTION_STRING', None)
        if not self._connection_string:
            raise ValueError(
                "PostgreSQL connection string is required. "
                "Set PGVECTOR_CONNECTION_STRING in settings or pass connection_string parameter."
            )

        try:
            import psycopg2
            from psycopg2.extras import Json
            import pgvector.psycopg2
        except ImportError:
            raise ImportError(
                "Could not import required packages. "
                "Please install with `pip install psycopg2-binary pgvector`."
            )

        self._psycopg2 = psycopg2
        self._Json = Json
        self._pgvector = pgvector.psycopg2
        self._connection = None

        self._ensure_table_exists()
    def _get_connection(self):
        """Get or create database connection"""
        if self._connection is None or self._connection.closed:
            self._connection = self._psycopg2.connect(self._connection_string)
            # Register pgvector types
            self._pgvector.register_vector(self._connection)
        return self._connection

    def _ensure_table_exists(self):
        """Create table and enable pgvector extension if they don't exist"""
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            # Enable pgvector extension
            cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")

            # Get embedding dimension
            embedding_dim = getattr(self._embedding, 'dimension', 1536)  # Default to OpenAI dimension

            # Create table with vector column
            create_table_query = f"""
                CREATE TABLE IF NOT EXISTS {self._table_name} (
                    id SERIAL PRIMARY KEY,
                    {self._text_column} TEXT NOT NULL,
                    {self._vector_column} vector({embedding_dim}),
                    {self._metadata_column} JSONB,
                    source_id TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                );
            """
            cursor.execute(create_table_query)

            # Create index for vector similarity search
            index_query = f"""
                CREATE INDEX IF NOT EXISTS {self._table_name}_{self._vector_column}_idx
                ON {self._table_name} USING ivfflat ({self._vector_column} vector_cosine_ops)
                WITH (lists = 100);
            """
            cursor.execute(index_query)

            # Create index for source_id filtering
            source_index_query = f"""
                CREATE INDEX IF NOT EXISTS {self._table_name}_source_id_idx
                ON {self._table_name} (source_id);
            """
            cursor.execute(source_index_query)

            conn.commit()
        except Exception as e:
            conn.rollback()
            logging.error(f"Error creating table: {e}")
            raise
        finally:
            cursor.close()
    def search(self, question: str, k: int = 2, *args, **kwargs) -> List[Document]:
        """Search for similar documents using vector similarity"""
        query_vector = self._embedding.embed_query(question)

        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            # Use cosine distance for similarity search with proper vector formatting
            search_query = f"""
                SELECT {self._text_column}, {self._metadata_column},
                       ({self._vector_column} <=> %s::vector) as distance
                FROM {self._table_name}
                WHERE source_id = %s
                ORDER BY {self._vector_column} <=> %s::vector
                LIMIT %s;
            """
            cursor.execute(search_query, (query_vector, self._source_id, query_vector, k))
            results = cursor.fetchall()

            documents = []
            for text, metadata, distance in results:
                metadata = metadata or {}
                documents.append(Document(page_content=text, metadata=metadata))
            return documents
        except Exception as e:
            logging.error(f"Error searching documents: {e}", exc_info=True)
            return []
        finally:
            cursor.close()
    def add_texts(
        self,
        texts: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        *args,
        **kwargs,
    ) -> List[str]:
        """Add texts with their embeddings to the vector store"""
        if not texts:
            return []

        embeddings = self._embedding.embed_documents(texts)
        metadatas = metadatas or [{}] * len(texts)

        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            insert_query = f"""
                INSERT INTO {self._table_name} ({self._text_column}, {self._vector_column}, {self._metadata_column}, source_id)
                VALUES (%s, %s, %s, %s)
                RETURNING id;
            """
            inserted_ids = []
            for text, embedding, metadata in zip(texts, embeddings, metadatas):
                cursor.execute(
                    insert_query,
                    (text, embedding, self._Json(metadata), self._source_id)
                )
                inserted_id = cursor.fetchone()[0]
                inserted_ids.append(str(inserted_id))
            conn.commit()
            return inserted_ids
        except Exception as e:
            conn.rollback()
            logging.error(f"Error adding texts: {e}")
            raise
        finally:
            cursor.close()
    def delete_index(self, *args, **kwargs):
        """Delete all documents for this source_id"""
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            delete_query = f"DELETE FROM {self._table_name} WHERE source_id = %s;"
            cursor.execute(delete_query, (self._source_id,))
            conn.commit()
        except Exception as e:
            conn.rollback()
            logging.error(f"Error deleting index: {e}")
            raise
        finally:
            cursor.close()

    def save_local(self, *args, **kwargs):
        """No-op for PostgreSQL - data is already persisted"""
        pass
    def get_chunks(self) -> List[Dict[str, Any]]:
        """Get all chunks for this source_id"""
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            select_query = f"""
                SELECT id, {self._text_column}, {self._metadata_column}
                FROM {self._table_name}
                WHERE source_id = %s;
            """
            cursor.execute(select_query, (self._source_id,))
            results = cursor.fetchall()

            chunks = []
            for doc_id, text, metadata in results:
                chunks.append({
                    "doc_id": str(doc_id),
                    "text": text,
                    "metadata": metadata or {}
                })
            return chunks
        except Exception as e:
            logging.error(f"Error getting chunks: {e}")
            return []
        finally:
            cursor.close()
    def add_chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        """Add a single chunk to the vector store"""
        metadata = metadata or {}
        # Create a copy to avoid modifying the original metadata
        final_metadata = metadata.copy()
        # Ensure the source_id is in the metadata so the chunk can be found by filters
        final_metadata["source_id"] = self._source_id

        embeddings = self._embedding.embed_documents([text])
        if not embeddings:
            raise ValueError("Could not generate embedding for chunk")

        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            insert_query = f"""
                INSERT INTO {self._table_name} ({self._text_column}, {self._vector_column}, {self._metadata_column}, source_id)
                VALUES (%s, %s, %s, %s)
                RETURNING id;
            """
            cursor.execute(
                insert_query,
                (text, embeddings[0], self._Json(final_metadata), self._source_id)
            )
            inserted_id = cursor.fetchone()[0]
            conn.commit()
            return str(inserted_id)
        except Exception as e:
            conn.rollback()
            logging.error(f"Error adding chunk: {e}")
            raise
        finally:
            cursor.close()
    def delete_chunk(self, chunk_id: str) -> bool:
        """Delete a specific chunk by its ID"""
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            delete_query = f"DELETE FROM {self._table_name} WHERE id = %s AND source_id = %s;"
            cursor.execute(delete_query, (int(chunk_id), self._source_id))
            deleted_count = cursor.rowcount
            conn.commit()
            return deleted_count > 0
        except Exception as e:
            conn.rollback()
            logging.error(f"Error deleting chunk: {e}")
            return False
        finally:
            cursor.close()

    def __del__(self):
        """Close database connection when object is destroyed"""
        if hasattr(self, '_connection') and self._connection and not self._connection.closed:
            self._connection.close()
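A short end-to-end sketch of the new store; the source_id and texts are placeholders, and PGVECTOR_CONNECTION_STRING must point at a database where CREATE EXTENSION vector is permitted:

from application.vectorstore.pgvector import PGVectorStore

store = PGVectorStore(source_id="my-docs")  # table and ivfflat index are created on first use

ids = store.add_texts(
    ["pgvector stores embeddings inside PostgreSQL.",
     "An ivfflat index trades a little recall for query speed."],
    metadatas=[{"title": "intro"}, {"title": "indexing"}],
)

for doc in store.search("Where are embeddings stored?", k=2):
    print(doc.metadata.get("title"), "->", doc.page_content)

store.delete_chunk(ids[0])  # remove one row by its SERIAL id
store.delete_index()        # remove every row for this source_id

One caveat worth noting: the ivfflat index here is created with lists = 100 on an empty table, while pgvector's own guidance is to build ivfflat indexes after the table has data, since the list centroids are computed from the rows present at build time.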

View File

@@ -3,6 +3,7 @@ from application.vectorstore.elasticsearch import ElasticsearchStore
 from application.vectorstore.milvus import MilvusStore
 from application.vectorstore.mongodb import MongoDBVectorStore
 from application.vectorstore.qdrant import QdrantStore
+from application.vectorstore.pgvector import PGVectorStore
 
 
 class VectorCreator:
@@ -12,6 +13,7 @@ class VectorCreator:
"mongodb": MongoDBVectorStore, "mongodb": MongoDBVectorStore,
"qdrant": QdrantStore, "qdrant": QdrantStore,
"milvus": MilvusStore, "milvus": MilvusStore,
"pgvector": PGVectorStore
} }
@classmethod @classmethod
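With the registration above, callers can construct the store by name. A hedged sketch, assuming the truncated @classmethod is a create_vectorstore factory that looks the type up in the mapping and forwards the remaining arguments (its body and the module path are not shown in this diff):

from application.vectorstore.vector_creator import VectorCreator  # assumed module path

# "pgvector" resolves to PGVectorStore; keyword arguments reach PGVectorStore.__init__.
store = VectorCreator.create_vectorstore("pgvector", source_id="my-docs")
docs = store.search("example query", k=4)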