Mirror of https://github.com/arc53/DocsGPT.git
fix: improve remote embeds (#2193)
@@ -12,6 +12,7 @@ class RemoteEmbeddings:
     """
     Wrapper for remote embeddings API (OpenAI-compatible).
     Used when EMBEDDINGS_BASE_URL is configured.
+    Sends requests to {base_url}/v1/embeddings in OpenAI format.
     """
 
     def __init__(self, api_url: str, model_name: str, api_key: str = None):
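For context, this is the request shape the wrapper now targets. The sketch below is illustrative only; the base URL, model name, and key are placeholders rather than values from DocsGPT, but the payload keys ("input", "model") and the /v1/embeddings path follow the OpenAI embeddings convention the docstring refers to.

```python
# Illustrative sketch of an OpenAI-compatible embeddings request (placeholder values).
import requests

base_url = "http://localhost:8080"  # placeholder for EMBEDDINGS_BASE_URL
payload = {
    "input": ["first chunk", "second chunk"],  # note: "input", not "inputs"
    "model": "nomic-embed-text",               # optional; only sent when a model name is set
}
headers = {
    "Content-Type": "application/json",
    "Authorization": "Bearer sk-placeholder",  # only added when an API key is configured
}

response = requests.post(
    f"{base_url}/v1/embeddings", headers=headers, json=payload, timeout=180
)
response.raise_for_status()
print(len(response.json()["data"]))  # one entry per input string
```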
@@ -20,33 +21,30 @@ class RemoteEmbeddings:
         self.headers = {"Content-Type": "application/json"}
         if api_key:
             self.headers["Authorization"] = f"Bearer {api_key}"
-        self.dimension = None
+        self.dimension = 768
 
     def _embed(self, inputs):
-        """Send embedding request to remote API."""
-        payload = {"inputs": inputs}
+        """Send embedding request to remote API in OpenAI-compatible format."""
+        payload = {"input": inputs}
         if self.model_name:
             payload["model"] = self.model_name
 
-        response = requests.post(
-            self.api_url, headers=self.headers, json=payload, timeout=180
-        )
+        url = f"{self.api_url}/v1/embeddings"
+        response = requests.post(url, headers=self.headers, json=payload, timeout=180)
         response.raise_for_status()
         result = response.json()
 
-        if isinstance(result, list):
-            if result and isinstance(result[0], list):
-                return result
-            elif result and all(isinstance(x, (int, float)) for x in result):
-                return [result]
-            elif not result:
-                return []
-            else:
-                raise ValueError(
-                    f"Unexpected list content from remote embeddings API: {result}"
-                )
-        elif isinstance(result, dict) and "error" in result:
-            raise ValueError(f"Remote embeddings API error: {result['error']}")
+        # Handle OpenAI-compatible response format
+        if isinstance(result, dict):
+            if "error" in result:
+                raise ValueError(f"Remote embeddings API error: {result['error']}")
+            if "data" in result:
+                # Sort by index to ensure correct order
+                data = sorted(result["data"], key=lambda x: x.get("index", 0))
+                return [item["embedding"] for item in data]
+            raise ValueError(
+                f"Unexpected response format from remote embeddings API: {result}"
+            )
         else:
             raise ValueError(
                 f"Unexpected response format from remote embeddings API: {result}"
@@ -11,6 +11,7 @@ class PGVectorStore(BaseVectorStore):
         source_id: str = "",
         embeddings_key: str = "embeddings",
         table_name: str = "documents",
         decoded_token: Optional[str] = None,
         vector_column: str = "embedding",
         text_column: str = "text",
         metadata_column: str = "metadata",
@@ -68,8 +69,7 @@ class PGVectorStore(BaseVectorStore):
             # Enable pgvector extension
             cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
 
-            # Get embedding dimension
-            embedding_dim = getattr(self._embedding, 'dimension', 1536)  # Default to OpenAI dimension
+            embedding_dim = getattr(self._embedding, 'dimension', 768)
 
             # Create table with vector column
             create_table_query = f"""
@@ -152,7 +152,7 @@ class PGVectorStore(BaseVectorStore):
         """Add texts with their embeddings to the vector store"""
         if not texts:
             return []
 
         embeddings = self._embedding.embed_documents(texts)
         metadatas = metadatas or [{}] * len(texts)
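A tiny illustration of the padding on the last line above, with made-up values: when the caller passes no metadata, the store still ends up with one dict per text before rows are written.

```python
texts = ["alpha", "beta", "gamma"]
metadatas = None                                # caller supplied no metadata
metadatas = metadatas or [{}] * len(texts)      # same padding as add_texts
assert metadatas == [{}, {}, {}]

fake_embeddings = [[0.0] * 768 for _ in texts]  # stand-in for embed_documents(texts)
rows = list(zip(texts, fake_embeddings, metadatas))
assert len(rows) == 3                           # the actual INSERT is not shown in this hunk
```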
@@ -239,15 +239,13 @@ class PGVectorStore(BaseVectorStore):
     def add_chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> str:
         """Add a single chunk to the vector store"""
         metadata = metadata or {}
 
         # Create a copy to avoid modifying the original metadata
         final_metadata = metadata.copy()
 
         # Ensure the source_id is in the metadata so the chunk can be found by filters
         final_metadata["source_id"] = self._source_id
 
         embeddings = self._embedding.embed_documents([text])
 
         if not embeddings:
             raise ValueError("Could not generate embedding for chunk")