Files
DocsGPT/application/vectorstore/base.py
2025-12-16 13:59:17 +02:00

223 lines
8.1 KiB
Python

import logging
import os
from abc import ABC, abstractmethod
from typing import Optional

import requests
from langchain_openai import OpenAIEmbeddings

from application.core.settings import settings
class RemoteEmbeddings:
    """
    Wrapper for remote embeddings API (OpenAI-compatible).
    Used when EMBEDDINGS_BASE_URL is configured.
    Sends requests to {base_url}/v1/embeddings in OpenAI format.

    Attributes:
        api_url: Base URL with trailing slashes removed.
        model_name: Sent as the ``model`` field of each request when truthy.
        headers: HTTP headers sent with every request (JSON content type,
            plus a Bearer ``Authorization`` header when an API key is given).
        dimension: Embedding size. Starts at the ``dimension`` constructor
            argument (768 by default) and is overwritten with the length of
            the vectors returned by every successful embedding call.
    """

    def __init__(
        self,
        api_url: str,
        model_name: str,
        api_key: Optional[str] = None,
        dimension: Optional[int] = 768,
    ):
        """Create a client for ``api_url``.

        Args:
            api_url: Base URL of the OpenAI-compatible server.
            model_name: Model identifier to request (omitted when falsy).
            api_key: Optional key sent as ``Authorization: Bearer <key>``.
            dimension: Initial guess for the embedding size, used until the
                first successful embedding call reports the real size.
        """
        self.api_url = api_url.rstrip("/")
        self.model_name = model_name
        self.headers = {"Content-Type": "application/json"}
        if api_key:
            self.headers["Authorization"] = f"Bearer {api_key}"
        self.dimension = dimension

    def _embed(self, inputs):
        """Send an embedding request and return the list of vectors.

        Args:
            inputs: A single string or a list of strings, sent as ``input``.

        Raises:
            requests.HTTPError: on a non-2xx HTTP status.
            ValueError: if the response body is an error or has no ``data``.
        """
        payload = {"input": inputs}
        if self.model_name:
            payload["model"] = self.model_name
        response = requests.post(
            f"{self.api_url}/v1/embeddings",
            headers=self.headers,
            json=payload,
            timeout=180,
        )
        response.raise_for_status()
        return self._parse_response(response.json())

    @staticmethod
    def _parse_response(result):
        """Extract vectors from an OpenAI-style embeddings response body.

        Items in ``data`` are sorted by their ``index`` field (missing index
        counts as 0) so the output order matches the request order.

        Raises:
            ValueError: if ``result`` is not a dict, carries an ``error`` key,
                or lacks a ``data`` key.
        """
        if not isinstance(result, dict):
            raise ValueError(
                f"Unexpected response format from remote embeddings API: {result}"
            )
        if "error" in result:
            raise ValueError(f"Remote embeddings API error: {result['error']}")
        if "data" not in result:
            raise ValueError(
                f"Unexpected response format from remote embeddings API: {result}"
            )
        items = sorted(result["data"], key=lambda item: item.get("index", 0))
        return [item["embedding"] for item in items]

    def embed_query(self, query: str):
        """Embed a single query string and return its vector.

        Raises:
            ValueError: if the API does not return exactly one vector.
        """
        vectors = self._embed(query)
        if (
            isinstance(vectors, list)
            and len(vectors) == 1
            and isinstance(vectors[0], list)
        ):
            # Record the real size reported by the server.
            self.dimension = len(vectors[0])
            return vectors[0]
        raise ValueError(
            f"Unexpected result structure after embedding query: {vectors}"
        )

    def embed_documents(self, documents: list):
        """Embed a list of documents, returning one vector per document.

        An empty input returns ``[]`` without contacting the server.

        Raises:
            ValueError: if the server returns a different number of vectors
                than documents sent (the results could not be aligned).
        """
        if not documents:
            return []
        vectors = self._embed(documents)
        if len(vectors) != len(documents):
            raise ValueError(
                f"Remote embeddings API returned {len(vectors)} embeddings "
                f"for {len(documents)} documents"
            )
        # Record the real size reported by the server.
        self.dimension = len(vectors[0])
        return vectors

    def __call__(self, text):
        """Dispatch to embed_query for a str, embed_documents for a list."""
        if isinstance(text, str):
            return self.embed_query(text)
        elif isinstance(text, list):
            return self.embed_documents(text)
        else:
            raise ValueError("Input must be a string or a list of strings")
def _get_embeddings_wrapper():
    """Import and return the local ``EmbeddingsWrapper`` class on demand.

    The import is deferred to call time so that SentenceTransformer is not
    loaded when only remote embeddings are in use.
    """
    from application.vectorstore.embeddings_local import (
        EmbeddingsWrapper as wrapper_cls,
    )

    return wrapper_cls
class EmbeddingsSingleton:
    """Process-wide cache of embeddings objects.

    ``get_instance`` creates an embeddings object the first time a given
    name is requested and returns that same object on every later request.
    The cache is keyed by name only, so any extra constructor arguments take
    effect only on the call that creates the instance.
    """

    # Created embeddings objects. Normally keyed by embeddings name;
    # BaseVectorStore._get_embeddings also stores remote clients here under
    # its own "remote_..." keys.
    _instances = {}

    # Accepted HuggingFace embeddings names mapped to the model identifier
    # passed to EmbeddingsWrapper. Names not listed here are passed through
    # to EmbeddingsWrapper unchanged (e.g. a local model path).
    _HF_MODEL_ALIASES = {
        "huggingface_sentence-transformers/all-mpnet-base-v2": (
            "sentence-transformers/all-mpnet-base-v2"
        ),
        "huggingface_sentence-transformers-all-mpnet-base-v2": (
            "sentence-transformers/all-mpnet-base-v2"
        ),
        "huggingface_hkunlp/instructor-large": "hkunlp/instructor-large",
    }

    @staticmethod
    def get_instance(embeddings_name, *args, **kwargs):
        """Return the cached embeddings object for ``embeddings_name``.

        ``args``/``kwargs`` are forwarded to the constructor only when no
        instance exists yet for this name; otherwise they are ignored.
        """
        instances = EmbeddingsSingleton._instances
        if embeddings_name not in instances:
            instances[embeddings_name] = EmbeddingsSingleton._create_instance(
                embeddings_name, *args, **kwargs
            )
        return instances[embeddings_name]

    @staticmethod
    def _create_instance(embeddings_name, *args, **kwargs):
        """Build a new embeddings object for ``embeddings_name``.

        ``openai_text-embedding-ada-002`` yields ``OpenAIEmbeddings``; every
        other name yields a local ``EmbeddingsWrapper``, with known
        HuggingFace aliases translated to their model identifier. Extra
        arguments are forwarded to the constructor in all cases.
        """
        if embeddings_name == "openai_text-embedding-ada-002":
            return OpenAIEmbeddings(*args, **kwargs)
        # Lazy import EmbeddingsWrapper only when needed (avoids loading
        # SentenceTransformer).
        wrapper_cls = _get_embeddings_wrapper()
        model_name = EmbeddingsSingleton._HF_MODEL_ALIASES.get(
            embeddings_name, embeddings_name
        )
        return wrapper_cls(model_name, *args, **kwargs)
class BaseVectorStore(ABC):
    """Common interface for vector store backends.

    Subclasses must implement :meth:`search` and :meth:`add_texts`. The other
    operations are optional hooks whose default implementations do nothing
    and return ``None``. :meth:`_get_embeddings` resolves an embeddings name
    to a cached embeddings object for subclasses to use.
    """

    # Locations checked, in order, for a pre-downloaded all-mpnet-base-v2
    # model before falling back to a HuggingFace download.
    _LOCAL_MPNET_PATHS = (
        "/app/models/all-mpnet-base-v2",  # Docker absolute path
        "./models/all-mpnet-base-v2",  # Relative path (to the CWD)
    )

    def __init__(self):
        pass

    @abstractmethod
    def search(self, *args, **kwargs):
        """Search for similar documents/chunks in the vectorstore"""

    @abstractmethod
    def add_texts(self, texts, metadatas=None, *args, **kwargs):
        """Add texts with their embeddings to the vectorstore"""

    def delete_index(self, *args, **kwargs):
        """Delete the entire index/collection"""

    def save_local(self, *args, **kwargs):
        """Save vectorstore to local storage"""

    def get_chunks(self, *args, **kwargs):
        """Get all chunks from the vectorstore"""

    def add_chunk(self, text, metadata=None, *args, **kwargs):
        """Add a single chunk to the vectorstore"""

    def delete_chunk(self, chunk_id, *args, **kwargs):
        """Delete a specific chunk from the vectorstore"""

    def is_azure_configured(self):
        """Return True when all settings required for Azure OpenAI are set.

        Checks ``OPENAI_API_BASE``, ``OPENAI_API_VERSION`` and
        ``AZURE_DEPLOYMENT_NAME``; any falsy value means "not configured".
        """
        return bool(
            settings.OPENAI_API_BASE
            and settings.OPENAI_API_VERSION
            and settings.AZURE_DEPLOYMENT_NAME
        )

    @staticmethod
    def _find_local_model_path(candidate_paths):
        """Return the first path in ``candidate_paths`` that exists, else None.

        Every checked path is logged at INFO level.
        """
        for path in candidate_paths:
            if os.path.exists(path):
                logging.info("Found local model at path: %s", path)
                return path
            logging.info("Path does not exist: %s", path)
        return None

    def _get_embeddings(self, embeddings_name, embeddings_key=None):
        """Return a cached embeddings object for ``embeddings_name``.

        Resolution order:
          1. If ``settings.EMBEDDINGS_BASE_URL`` is set, a ``RemoteEmbeddings``
             client for that URL is used regardless of the name.
          2. ``openai_text-embedding-ada-002`` uses the Azure deployment when
             Azure is configured (and sets ``OPENAI_API_TYPE=azure`` in the
             process environment), otherwise ``embeddings_key`` as the OpenAI
             API key.
          3. ``huggingface_sentence-transformers/all-mpnet-base-v2`` prefers a
             pre-downloaded model from ``_LOCAL_MPNET_PATHS``.
          4. Any other name is handed to ``EmbeddingsSingleton`` as-is.
        """
        # Check for remote embeddings first
        if settings.EMBEDDINGS_BASE_URL:
            logging.info(
                "Using remote embeddings API at: %s", settings.EMBEDDINGS_BASE_URL
            )
            # NOTE(review): embeddings_key is not part of the cache key, so the
            # first key used for a given URL/model is reused by later calls.
            cache_key = f"remote_{settings.EMBEDDINGS_BASE_URL}_{embeddings_name}"
            instances = EmbeddingsSingleton._instances
            if cache_key not in instances:
                instances[cache_key] = RemoteEmbeddings(
                    api_url=settings.EMBEDDINGS_BASE_URL,
                    model_name=embeddings_name,
                    api_key=embeddings_key,
                )
            return instances[cache_key]

        if embeddings_name == "openai_text-embedding-ada-002":
            if self.is_azure_configured():
                os.environ["OPENAI_API_TYPE"] = "azure"
                return EmbeddingsSingleton.get_instance(
                    embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
                )
            return EmbeddingsSingleton.get_instance(
                embeddings_name, openai_api_key=embeddings_key
            )

        if embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
            local_model_path = self._find_local_model_path(self._LOCAL_MPNET_PATHS)
            if local_model_path:
                # The local path itself becomes the singleton cache key.
                return EmbeddingsSingleton.get_instance(local_model_path)
            logging.warning(
                "Local model not found in any of the paths: %s. "
                "Falling back to HuggingFace download.",
                list(self._LOCAL_MPNET_PATHS),
            )
            return EmbeddingsSingleton.get_instance(embeddings_name)

        return EmbeddingsSingleton.get_instance(embeddings_name)