mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Merge pull request #1906 from arc53/feat/pg-vector
feat: implement PGVectorStore for PostgreSQL vector storage
This commit is contained in:
@@ -62,12 +62,15 @@ user_logs_collection = db["user_logs"]
|
|||||||
user_tools_collection = db["user_tools"]
attachments_collection = db["attachments"]

# Index creation is best-effort at import time: a failure (e.g. no
# permissions, replica lag) must not prevent the module from loading.
try:
    agents_collection.create_index(
        [("shared_publicly", 1)],
        name="shared_index",
        background=True,
    )
    users_collection.create_index("user_id", unique=True)
except Exception as e:
    print("Error creating indexes:", e)

user = Blueprint("user", __name__)
user_ns = Namespace("user", description="User related operations", path="/")
|
|||||||
@@ -30,6 +30,7 @@ class Settings(BaseSettings):
|
|||||||
}
|
}
|
||||||
UPLOAD_FOLDER: str = "inputs"
|
UPLOAD_FOLDER: str = "inputs"
|
||||||
PARSE_PDF_AS_IMAGE: bool = False
|
PARSE_PDF_AS_IMAGE: bool = False
|
||||||
|
PARSE_IMAGE_REMOTE: bool = False
|
||||||
VECTOR_STORE: str = (
|
VECTOR_STORE: str = (
|
||||||
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
|
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
|
||||||
)
|
)
|
||||||
@@ -89,6 +90,8 @@ class Settings(BaseSettings):
|
|||||||
QDRANT_PATH: Optional[str] = None
|
QDRANT_PATH: Optional[str] = None
|
||||||
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
||||||
|
|
||||||
|
# PGVector vectorstore config
|
||||||
|
PGVECTOR_CONNECTION_STRING: Optional[str] = None
|
||||||
# Milvus vectorstore config
|
# Milvus vectorstore config
|
||||||
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
|
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
|
||||||
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
|
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import requests
|
|||||||
from typing import Dict, Union
|
from typing import Dict, Union
|
||||||
|
|
||||||
from application.parser.file.base_parser import BaseParser
|
from application.parser.file.base_parser import BaseParser
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
|
||||||
class ImageParser(BaseParser):
|
class ImageParser(BaseParser):
|
||||||
@@ -18,10 +19,13 @@ class ImageParser(BaseParser):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, list[str]]:
    """Parse an image file into markdown text.

    When ``settings.PARSE_IMAGE_REMOTE`` is enabled, the image is uploaded
    to the remote doc2md service and the markdown it returns is used;
    otherwise an empty string is returned.

    Args:
        file: Path of the image file on disk.
        errors: Unused; kept for BaseParser interface compatibility.

    Returns:
        The markdown produced by the remote service, or "" when remote
        parsing is disabled.
    """
    if settings.PARSE_IMAGE_REMOTE:
        doc2md_service = "https://llm.arc53.com/doc2md"
        # alternatively you can use local vision capable LLM
        with open(file, "rb") as file_loaded:
            files = {'file': file_loaded}
            # A request without a timeout can hang the worker forever if
            # the service stalls; bound the remote call explicitly.
            response = requests.post(doc2md_service, files=files, timeout=300)
            data = response.json()["markdown"]
    else:
        data = ""
    return data
|
||||||
|
|||||||
303
application/vectorstore/pgvector.py
Normal file
303
application/vectorstore/pgvector.py
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
import logging
|
||||||
|
from typing import List, Optional, Any, Dict
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.vectorstore.base import BaseVectorStore
|
||||||
|
from application.vectorstore.document_class import Document
|
||||||
|
|
||||||
|
|
||||||
|
class PGVectorStore(BaseVectorStore):
    """Vector store backed by PostgreSQL with the pgvector extension.

    Each chunk lives in one row of ``table_name`` with a text column, a
    pgvector column, a JSONB metadata column and a ``source_id`` column;
    every operation on an instance is scoped to its ``source_id``.
    """

    def __init__(
        self,
        source_id: str = "",
        embeddings_key: str = "embeddings",
        table_name: str = "documents",
        vector_column: str = "embedding",
        text_column: str = "text",
        metadata_column: str = "metadata",
        connection_string: Optional[str] = None,
    ):
        """Create the store and ensure the backing table exists.

        Args:
            source_id: Identifier of the document source; the common
                "application/indexes/" prefix and any trailing slash are
                stripped so ids match what other stores use.
            embeddings_key: Key used to resolve the embedding backend.
            table_name: Table holding the chunks.
            vector_column: Name of the pgvector column.
            text_column: Name of the text column.
            metadata_column: Name of the JSONB metadata column.
            connection_string: Overrides settings.PGVECTOR_CONNECTION_STRING.

        Raises:
            ValueError: If no connection string is available, or a table /
                column name is not a plain identifier.
            ImportError: If psycopg2 or pgvector is not installed.
        """
        super().__init__()

        # Table/column names are interpolated into SQL text below
        # (identifiers cannot be bound as query parameters), so refuse
        # anything that is not a bare identifier to rule out SQL
        # injection via configuration values.
        for identifier in (table_name, vector_column, text_column, metadata_column):
            if not str(identifier).isidentifier():
                raise ValueError(f"Unsafe SQL identifier: {identifier!r}")

        # Store the source_id for use in add_chunk and row filtering.
        self._source_id = str(source_id).replace("application/indexes/", "").rstrip("/")
        self._embeddings_key = embeddings_key
        self._table_name = table_name
        self._vector_column = vector_column
        self._text_column = text_column
        self._metadata_column = metadata_column
        self._embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)

        # Use provided connection string or fall back to settings.
        self._connection_string = connection_string or getattr(
            settings, "PGVECTOR_CONNECTION_STRING", None
        )
        if not self._connection_string:
            raise ValueError(
                "PostgreSQL connection string is required. "
                "Set PGVECTOR_CONNECTION_STRING in settings or pass connection_string parameter."
            )

        try:
            import psycopg2
            from psycopg2.extras import Json
            import pgvector.psycopg2
        except ImportError as exc:
            raise ImportError(
                "Could not import required packages. "
                "Please install with `pip install psycopg2-binary pgvector`."
            ) from exc

        self._psycopg2 = psycopg2
        self._Json = Json
        self._pgvector = pgvector.psycopg2
        self._connection = None
        self._ensure_table_exists()

    def _get_connection(self):
        """Return a live connection, reconnecting if the old one closed."""
        if self._connection is None or self._connection.closed:
            self._connection = self._psycopg2.connect(self._connection_string)
            # Register pgvector adapters so vector columns round-trip.
            self._pgvector.register_vector(self._connection)
        return self._connection

    def _ensure_table_exists(self):
        """Enable the pgvector extension and create the table and indexes."""
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")

            # Dimension of the vector column; falls back to 1536 (OpenAI
            # default) when the embedding object does not expose one.
            embedding_dim = getattr(self._embedding, "dimension", 1536)

            cursor.execute(
                f"""
                CREATE TABLE IF NOT EXISTS {self._table_name} (
                    id SERIAL PRIMARY KEY,
                    {self._text_column} TEXT NOT NULL,
                    {self._vector_column} vector({embedding_dim}),
                    {self._metadata_column} JSONB,
                    source_id TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                );
                """
            )

            # Approximate-nearest-neighbour index for cosine search.
            cursor.execute(
                f"""
                CREATE INDEX IF NOT EXISTS {self._table_name}_{self._vector_column}_idx
                ON {self._table_name} USING ivfflat ({self._vector_column} vector_cosine_ops)
                WITH (lists = 100);
                """
            )

            # B-tree index so per-source filtering stays cheap.
            cursor.execute(
                f"""
                CREATE INDEX IF NOT EXISTS {self._table_name}_source_id_idx
                ON {self._table_name} (source_id);
                """
            )

            conn.commit()
        except Exception as e:
            conn.rollback()
            logging.error(f"Error creating table: {e}")
            raise
        finally:
            cursor.close()

    def search(self, question: str, k: int = 2, *args, **kwargs) -> List[Document]:
        """Return up to ``k`` documents most similar to ``question``.

        Uses cosine distance (``<=>``) restricted to this store's
        source_id. Returns an empty list on any database error.
        """
        query_vector = self._embedding.embed_query(question)

        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            search_query = f"""
                SELECT {self._text_column}, {self._metadata_column},
                       ({self._vector_column} <=> %s::vector) as distance
                FROM {self._table_name}
                WHERE source_id = %s
                ORDER BY {self._vector_column} <=> %s::vector
                LIMIT %s;
            """
            cursor.execute(search_query, (query_vector, self._source_id, query_vector, k))

            documents = []
            # Distance is selected for ordering/debugging but is not part
            # of the Document contract, so it is discarded here.
            for text, metadata, _distance in cursor.fetchall():
                documents.append(Document(page_content=text, metadata=metadata or {}))
            return documents
        except Exception as e:
            logging.error(f"Error searching documents: {e}", exc_info=True)
            return []
        finally:
            cursor.close()

    def add_texts(
        self,
        texts: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        *args,
        **kwargs,
    ) -> List[str]:
        """Embed and insert ``texts``; return the new row ids as strings.

        Rolls the transaction back and re-raises on database failure.
        """
        if not texts:
            return []

        embeddings = self._embedding.embed_documents(texts)
        # Fresh dict per text: `[{}] * n` would share one dict object
        # across every row.
        metadatas = metadatas or [{} for _ in texts]

        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            insert_query = f"""
                INSERT INTO {self._table_name} ({self._text_column}, {self._vector_column}, {self._metadata_column}, source_id)
                VALUES (%s, %s, %s, %s)
                RETURNING id;
            """
            inserted_ids = []
            for text, embedding, metadata in zip(texts, embeddings, metadatas):
                cursor.execute(
                    insert_query,
                    (text, embedding, self._Json(metadata), self._source_id),
                )
                inserted_ids.append(str(cursor.fetchone()[0]))

            conn.commit()
            return inserted_ids
        except Exception as e:
            conn.rollback()
            logging.error(f"Error adding texts: {e}")
            raise
        finally:
            cursor.close()

    def delete_index(self, *args, **kwargs):
        """Delete every row belonging to this source_id."""
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute(
                f"DELETE FROM {self._table_name} WHERE source_id = %s;",
                (self._source_id,),
            )
            conn.commit()
        except Exception as e:
            conn.rollback()
            logging.error(f"Error deleting index: {e}")
            raise
        finally:
            cursor.close()

    def save_local(self, *args, **kwargs):
        """No-op for PostgreSQL - data is already persisted"""
        pass

    def get_chunks(self) -> List[Dict[str, Any]]:
        """Return all chunks for this source_id as plain dicts.

        Each item has "doc_id", "text" and "metadata" keys. Returns an
        empty list on any database error.
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute(
                f"""
                SELECT id, {self._text_column}, {self._metadata_column}
                FROM {self._table_name}
                WHERE source_id = %s;
                """,
                (self._source_id,),
            )
            return [
                {
                    "doc_id": str(doc_id),
                    "text": text,
                    "metadata": metadata or {},
                }
                for doc_id, text, metadata in cursor.fetchall()
            ]
        except Exception as e:
            logging.error(f"Error getting chunks: {e}")
            return []
        finally:
            cursor.close()

    def add_chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        """Embed and insert a single chunk; return its row id as a string.

        The stored metadata always carries this store's source_id so the
        chunk can also be found through metadata filters.

        Raises:
            ValueError: If no embedding could be generated for the text.
        """
        metadata = metadata or {}

        # Copy so the caller's dict is not mutated.
        final_metadata = metadata.copy()
        final_metadata["source_id"] = self._source_id

        embeddings = self._embedding.embed_documents([text])
        if not embeddings:
            raise ValueError("Could not generate embedding for chunk")

        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            insert_query = f"""
                INSERT INTO {self._table_name} ({self._text_column}, {self._vector_column}, {self._metadata_column}, source_id)
                VALUES (%s, %s, %s, %s)
                RETURNING id;
            """
            cursor.execute(
                insert_query,
                (text, embeddings[0], self._Json(final_metadata), self._source_id),
            )
            inserted_id = cursor.fetchone()[0]
            conn.commit()
            return str(inserted_id)
        except Exception as e:
            conn.rollback()
            logging.error(f"Error adding chunk: {e}")
            raise
        finally:
            cursor.close()

    def delete_chunk(self, chunk_id: str) -> bool:
        """Delete one chunk by id; True if a row was removed.

        Returns False (after rolling back) on invalid ids or database
        errors instead of raising.
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute(
                f"DELETE FROM {self._table_name} WHERE id = %s AND source_id = %s;",
                (int(chunk_id), self._source_id),
            )
            deleted_count = cursor.rowcount
            conn.commit()
            return deleted_count > 0
        except Exception as e:
            conn.rollback()
            logging.error(f"Error deleting chunk: {e}")
            return False
        finally:
            cursor.close()

    def __del__(self):
        """Close the database connection when the object is destroyed."""
        if hasattr(self, "_connection") and self._connection and not self._connection.closed:
            self._connection.close()
||||||
@@ -3,6 +3,7 @@ from application.vectorstore.elasticsearch import ElasticsearchStore
|
|||||||
from application.vectorstore.milvus import MilvusStore
|
from application.vectorstore.milvus import MilvusStore
|
||||||
from application.vectorstore.mongodb import MongoDBVectorStore
|
from application.vectorstore.mongodb import MongoDBVectorStore
|
||||||
from application.vectorstore.qdrant import QdrantStore
|
from application.vectorstore.qdrant import QdrantStore
|
||||||
|
from application.vectorstore.pgvector import PGVectorStore
|
||||||
|
|
||||||
|
|
||||||
class VectorCreator:
|
class VectorCreator:
|
||||||
@@ -12,6 +13,7 @@ class VectorCreator:
|
|||||||
"mongodb": MongoDBVectorStore,
|
"mongodb": MongoDBVectorStore,
|
||||||
"qdrant": QdrantStore,
|
"qdrant": QdrantStore,
|
||||||
"milvus": MilvusStore,
|
"milvus": MilvusStore,
|
||||||
|
"pgvector": PGVectorStore
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user