mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-04-26 11:25:45 +00:00
fix simple routing (#2261)
This commit is contained in:
@@ -187,3 +187,18 @@ AZURE_OPENAI_MODELS = [
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
|
||||
"""Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
|
||||
return AvailableModel(
|
||||
id=model_name,
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name=model_name,
|
||||
description=f"Custom OpenAI-compatible model at {base_url}",
|
||||
base_url=base_url,
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -84,7 +84,9 @@ class ModelRegistry:
|
||||
|
||||
self.models.clear()
|
||||
|
||||
self._add_docsgpt_models(settings)
|
||||
# Skip DocsGPT model if using custom OpenAI-compatible endpoint
|
||||
if not settings.OPENAI_BASE_URL:
|
||||
self._add_docsgpt_models(settings)
|
||||
if settings.OPENAI_API_KEY or (
|
||||
settings.LLM_PROVIDER == "openai" and settings.API_KEY
|
||||
):
|
||||
@@ -125,19 +127,34 @@ class ModelRegistry:
|
||||
)
|
||||
|
||||
def _add_openai_models(self, settings):
|
||||
from application.core.model_configs import OPENAI_MODELS
|
||||
from application.core.model_configs import (
|
||||
OPENAI_MODELS,
|
||||
create_custom_openai_model,
|
||||
)
|
||||
|
||||
# Add standard OpenAI models if API key is present
|
||||
if settings.OPENAI_API_KEY:
|
||||
for model in OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "openai" and settings.LLM_NAME:
|
||||
|
||||
# Add custom model if OPENAI_BASE_URL is configured with a custom LLM_NAME
|
||||
if (
|
||||
settings.LLM_PROVIDER == "openai"
|
||||
and settings.OPENAI_BASE_URL
|
||||
and settings.LLM_NAME
|
||||
):
|
||||
custom_model = create_custom_openai_model(
|
||||
settings.LLM_NAME, settings.OPENAI_BASE_URL
|
||||
)
|
||||
self.models[settings.LLM_NAME] = custom_model
|
||||
logger.info(
|
||||
f"Registered custom OpenAI model: {settings.LLM_NAME} at {settings.OPENAI_BASE_URL}"
|
||||
)
|
||||
|
||||
# Fallback: add all OpenAI models if none were added
|
||||
if not any(m.provider.value == "openai" for m in self.models.values()):
|
||||
for model in OPENAI_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_azure_openai_models(self, settings):
|
||||
from application.core.model_configs import AZURE_OPENAI_MODELS
|
||||
|
||||
@@ -136,7 +136,7 @@ function PromptTextarea({
|
||||
</div>
|
||||
<textarea
|
||||
id={id}
|
||||
className="peer border-silver dark:border-silver/40 relative z-10 h-48 w-full resize-none rounded border-2 bg-transparent px-3 py-2 text-base text-gray-800 outline-none dark:bg-transparent dark:text-white"
|
||||
className="peer border-silver dark:border-silver/40 relative z-10 h-48 w-full resize-none rounded border-2 bg-transparent px-3 py-2 text-base text-gray-800 outline-none md:h-64 lg:h-80 dark:bg-transparent dark:text-white"
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
onScroll={handleScroll}
|
||||
@@ -765,7 +765,7 @@ export default function PromptsModal({
|
||||
setNewPromptContent('');
|
||||
}
|
||||
}}
|
||||
className="mx-4 mt-16 w-[95vw] max-w-[650px] rounded-2xl bg-white px-4 py-4 sm:px-6 sm:py-6 md:px-8 md:py-6 dark:bg-[#1E1E2A]"
|
||||
className="mx-4 mt-16 w-[95vw] max-w-[650px] rounded-2xl bg-white px-4 py-4 sm:px-6 sm:py-6 md:max-w-[860px] md:px-8 md:py-6 lg:max-w-[980px] dark:bg-[#1E1E2A]"
|
||||
contentClassName="!overflow-visible"
|
||||
>
|
||||
{view}
|
||||
|
||||
@@ -64,10 +64,12 @@ def test_model_without_base_url():
|
||||
def test_validate_model_id():
|
||||
"""Test model_id validation"""
|
||||
# Get the registry instance to check what models are available
|
||||
ModelRegistry.get_instance()
|
||||
registry = ModelRegistry.get_instance()
|
||||
|
||||
# Test with a model that should exist (docsgpt-local is always added)
|
||||
assert validate_model_id("docsgpt-local") is True
|
||||
# Test with a model that exists in the registry
|
||||
available_models = registry.get_all_models()
|
||||
if available_models:
|
||||
assert validate_model_id(available_models[0].id) is True
|
||||
|
||||
# Test with invalid model_id
|
||||
assert validate_model_id("invalid-model-xyz-123") is False
|
||||
@@ -79,14 +81,19 @@ def test_validate_model_id():
|
||||
@pytest.mark.unit
|
||||
def test_get_base_url_for_model():
|
||||
"""Test retrieving base_url for a model"""
|
||||
# Test with a model that doesn't have base_url
|
||||
result = get_base_url_for_model("docsgpt-local")
|
||||
assert result is None # docsgpt-local doesn't have custom base_url
|
||||
|
||||
# Test with invalid model
|
||||
result = get_base_url_for_model("invalid-model")
|
||||
assert result is None
|
||||
|
||||
# Test with a model that exists but may or may not have base_url
|
||||
registry = ModelRegistry.get_instance()
|
||||
available_models = registry.get_all_models()
|
||||
if available_models:
|
||||
model = available_models[0]
|
||||
result = get_base_url_for_model(model.id)
|
||||
# Result should match the model's base_url (could be None or a string)
|
||||
assert result == model.base_url
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_model_validation_error_message():
|
||||
|
||||
Reference in New Issue
Block a user