Mirror of https://github.com/arc53/DocsGPT.git (synced 2025-11-29 16:43:16 +00:00)
feat: model registry and capabilities for multi-provider support (#2158)
* feat: Implement model registry and capabilities for multi-provider support
  - Added ModelRegistry to manage available models and their capabilities.
  - Introduced ModelProvider enum for different LLM providers.
  - Created ModelCapabilities dataclass to define model features.
  - Implemented methods to load models based on API keys and settings.
  - Added utility functions for model management in model_utils.py.
  - Updated settings.py to include provider-specific API keys.
  - Refactored LLM classes (Anthropic, OpenAI, Google, etc.) to utilize the new model registry.
  - Enhanced utility functions to handle token limits and model validation.
  - Improved code structure and logging for better maintainability.
* feat: Add model selection feature with API integration and UI component
* feat: Add model selection and default model functionality in agent management
* test: Update assertions and formatting in stream processing tests
* refactor(llm): Standardize model identifier to model_id
* fix tests

Co-authored-by: Alex <a@tushynski.me>
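The registry types named in the commit message are defined elsewhere in the PR, not in this file's diff. A minimal sketch of what the message describes, where the enum values, capability fields, and method names are all assumptions for illustration:

    # Illustrative sketch only; the real definitions live elsewhere in the PR.
    from dataclasses import dataclass
    from enum import Enum

    class ModelProvider(Enum):
        # Enum values are assumptions; the PR only names the enum itself.
        OPENAI = "openai"
        ANTHROPIC = "anthropic"
        GOOGLE = "google"

    @dataclass
    class ModelCapabilities:
        # Hypothetical feature flags; the PR mentions token limits and validation.
        supports_tools: bool = False
        supports_attachments: bool = False
        token_limit: int = 8192

    class ModelRegistry:
        def __init__(self):
            self._models: dict[str, tuple[ModelProvider, ModelCapabilities]] = {}

        def register(self, model_id: str, provider: ModelProvider,
                     caps: ModelCapabilities) -> None:
            # model_id is the standardized identifier introduced by this PR.
            self._models[model_id] = (provider, caps)

        def get(self, model_id: str):
            return self._models.get(model_id)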
@@ -2,6 +2,8 @@ import base64
 import json
 import logging

+from openai import OpenAI
+
 from application.core.settings import settings
 from application.llm.base import BaseLLM
 from application.storage.storage_creator import StorageCreator
@@ -9,20 +11,25 @@ from application.storage.storage_creator import StorageCreator


 class OpenAILLM(BaseLLM):

-    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
-        from openai import OpenAI
-
+    def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        if (
+        self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
+        self.user_api_key = user_api_key
+
+        # Priority: 1) Parameter base_url, 2) Settings OPENAI_BASE_URL, 3) Default
+        effective_base_url = None
+        if base_url and isinstance(base_url, str) and base_url.strip():
+            effective_base_url = base_url
+        elif (
             isinstance(settings.OPENAI_BASE_URL, str)
             and settings.OPENAI_BASE_URL.strip()
         ):
-            self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
+            effective_base_url = settings.OPENAI_BASE_URL
         else:
-            DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
-            self.client = OpenAI(api_key=api_key, base_url=DEFAULT_OPENAI_API_BASE)
-        self.api_key = api_key
-        self.user_api_key = user_api_key
+            effective_base_url = "https://api.openai.com/v1"
+
+        self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
         self.storage = StorageCreator.get_storage()

     def _clean_messages_openai(self, messages):
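In practice the reworked constructor lets a caller point the client at any OpenAI-compatible endpoint. A hedged usage sketch (the local URL is an example, not anything from the PR):

    # The explicit parameter wins over settings.OPENAI_BASE_URL; with neither
    # set, the client falls back to https://api.openai.com/v1.
    llm = OpenAILLM(api_key="sk-...", base_url="http://localhost:8000/v1")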
@@ -33,7 +40,6 @@ class OpenAILLM(BaseLLM):

         if role == "model":
             role = "assistant"
-
         if role and content is not None:
             if isinstance(content, str):
                 cleaned_messages.append({"role": role, "content": content})
@@ -107,7 +113,6 @@ class OpenAILLM(BaseLLM):
                 )
             else:
                 raise ValueError(f"Unexpected content type: {type(content)}")
-
         return cleaned_messages

     def _raw_gen(
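The role normalization above means histories recorded with Gemini-style roles still produce valid OpenAI requests. A small illustration (the message contents are invented):

    messages = [
        {"role": "user", "content": "Hi"},
        {"role": "model", "content": "Hello!"},  # non-OpenAI role name
    ]
    # After _clean_messages_openai, "model" is rewritten to "assistant":
    # [{"role": "user", "content": "Hi"},
    #  {"role": "assistant", "content": "Hello!"}]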
@@ -132,10 +137,8 @@ class OpenAILLM(BaseLLM):

         if tools:
             request_params["tools"] = tools
-
         if response_format:
             request_params["response_format"] = response_format
-
         response = self.client.chat.completions.create(**request_params)

         if tools:
@@ -165,10 +168,8 @@ class OpenAILLM(BaseLLM):

         if tools:
             request_params["tools"] = tools
-
         if response_format:
             request_params["response_format"] = response_format
-
         response = self.client.chat.completions.create(**request_params)

         try:
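Both generation paths (plain and streaming) build the same conditional request, so a single sketch covers them. The model name and message are placeholders:

    tools, response_format = None, None  # as passed into _raw_gen
    request_params = {
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "Hi"}],
    }
    if tools:
        request_params["tools"] = tools
    if response_format:
        request_params["response_format"] = response_format
    response = llm.client.chat.completions.create(**request_params)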
@@ -194,7 +195,6 @@ class OpenAILLM(BaseLLM):
     def prepare_structured_output_format(self, json_schema):
         if not json_schema:
             return None
-
         try:

             def add_additional_properties_false(schema_obj):
@@ -204,11 +204,11 @@ class OpenAILLM(BaseLLM):
                 if schema_copy.get("type") == "object":
                     schema_copy["additionalProperties"] = False
                     # Ensure 'required' includes all properties for OpenAI strict mode
+
                     if "properties" in schema_copy:
                         schema_copy["required"] = list(
                             schema_copy["properties"].keys()
                         )
-
                 for key, value in schema_copy.items():
                     if key == "properties" and isinstance(value, dict):
                         schema_copy[key] = {
@@ -224,7 +224,6 @@ class OpenAILLM(BaseLLM):
                         add_additional_properties_false(sub_schema)
                         for sub_schema in value
                     ]
-
                 return schema_copy
             return schema_obj

@@ -243,7 +242,6 @@ class OpenAILLM(BaseLLM):
             }
-
             return result

         except Exception as e:
             logging.error(f"Error preparing structured output format: {e}")
             return None
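Put together, the recursive walker rewrites a plain JSON schema into the shape OpenAI's strict mode expects. A hedged before/after — the json_schema wrapper keys are inferred from OpenAI's response_format convention and are not visible in this hunk:

    schema_in = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    }
    # prepare_structured_output_format(schema_in) would yield roughly:
    # {"type": "json_schema",
    #  "json_schema": {"name": "...", "strict": True,
    #                  "schema": {"type": "object",
    #                             "additionalProperties": False,
    #                             "required": ["name", "age"],
    #                             "properties": {...}}}}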
@@ -277,21 +275,19 @@ class OpenAILLM(BaseLLM):
         """
         if not attachments:
             return messages
-
         prepared_messages = messages.copy()

         # Find the user message to attach file_id to the last one
-
         user_message_index = None
         for i in range(len(prepared_messages) - 1, -1, -1):
             if prepared_messages[i].get("role") == "user":
                 user_message_index = i
                 break

         if user_message_index is None:
             user_message = {"role": "user", "content": []}
             prepared_messages.append(user_message)
             user_message_index = len(prepared_messages) - 1

         if isinstance(prepared_messages[user_message_index].get("content"), str):
             text_content = prepared_messages[user_message_index]["content"]
             prepared_messages[user_message_index]["content"] = [
@@ -299,7 +295,6 @@ class OpenAILLM(BaseLLM):
             ]
         elif not isinstance(prepared_messages[user_message_index].get("content"), list):
             prepared_messages[user_message_index]["content"] = []
-
         for attachment in attachments:
             mime_type = attachment.get("mime_type")

@@ -326,6 +321,7 @@ class OpenAILLM(BaseLLM):
                     }
                 )
+            # Handle PDFs using the file API
             elif mime_type == "application/pdf":
                 try:
                     file_id = self._upload_file_to_openai(attachment)
@@ -341,7 +337,6 @@ class OpenAILLM(BaseLLM):
                             "text": f"File content:\n\n{attachment['content']}",
                         }
                     )
-
         return prepared_messages

     def _get_base64_image(self, attachment):
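The net effect of the attachment-handling path above is that the last user message's content becomes a list of typed parts: inline base64 for images, a file API id for PDFs, and plain text as a fallback. An invented example (paths, text, and file id are placeholders, and the part shapes are hedged from the surrounding context):

    attachments = [
        {"mime_type": "image/png", "path": "uploads/chart.png"},
        {"mime_type": "application/pdf", "path": "uploads/report.pdf"},
    ]
    # The last user message ends up shaped roughly like:
    # {"role": "user", "content": [
    #     {"type": "text", "text": "original question"},
    #     {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    #     {"type": "file", "file": {"file_id": "file-abc123"}},
    # ]}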
@@ -357,7 +352,6 @@ class OpenAILLM(BaseLLM):
         file_path = attachment.get("path")
         if not file_path:
             raise ValueError("No file path provided in attachment")
-
         try:
             with self.storage.get_file(file_path) as image_file:
                 return base64.b64encode(image_file.read()).decode("utf-8")
@@ -381,12 +375,10 @@ class OpenAILLM(BaseLLM):

         if "openai_file_id" in attachment:
             return attachment["openai_file_id"]
-
         file_path = attachment.get("path")
-
         if not self.storage.file_exists(file_path):
             raise FileNotFoundError(f"File not found: {file_path}")

         try:
             file_id = self.storage.process_file(
                 file_path,
@@ -404,7 +396,6 @@ class OpenAILLM(BaseLLM):
             attachments_collection.update_one(
                 {"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
             )
-
             return file_id
         except Exception as e:
             logging.error(f"Error uploading file to OpenAI: {e}", exc_info=True)
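A design note on the upload path: the returned OpenAI file id is written back onto the attachment's MongoDB document, so repeat requests with the same attachment skip the upload via the early-return at the top of the method. A sketch of the round trip (the attachment document and id are invented):

    attachment = attachments_collection.find_one({"_id": some_id})
    file_id = llm._upload_file_to_openai(attachment)  # uploads on first call
    # The method persisted the id, so a later call short-circuits:
    refreshed = attachments_collection.find_one({"_id": some_id})
    assert refreshed["openai_file_id"] == file_id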