fix simple routing (#2261)

This commit is contained in:
Alex
2026-01-12 13:51:19 +02:00
committed by GitHub
parent 2246866a09
commit a29bfa7489
4 changed files with 57 additions and 18 deletions

View File

@@ -187,3 +187,18 @@ AZURE_OPENAI_MODELS = [
),
),
]
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
    """Build an ``AvailableModel`` entry for an OpenAI-compatible endpoint.

    Intended for self-hosted servers that speak the OpenAI API but live at a
    custom base URL (e.g. LM Studio, Ollama). The model id and display name
    are both taken verbatim from ``model_name``.

    Args:
        model_name: Identifier of the model as exposed by the custom server.
        base_url: Root URL of the OpenAI-compatible API endpoint.

    Returns:
        An ``AvailableModel`` registered under the OPENAI provider with
        tool support enabled and the standard OpenAI attachment types.
    """
    # Capabilities mirror the stock OpenAI entries: tools on, same attachments.
    capabilities = ModelCapabilities(
        supports_tools=True,
        supported_attachment_types=OPENAI_ATTACHMENTS,
    )
    return AvailableModel(
        id=model_name,
        provider=ModelProvider.OPENAI,
        display_name=model_name,
        description=f"Custom OpenAI-compatible model at {base_url}",
        base_url=base_url,
        capabilities=capabilities,
    )

View File

@@ -84,7 +84,9 @@ class ModelRegistry:
self.models.clear()
self._add_docsgpt_models(settings)
# Skip DocsGPT model if using custom OpenAI-compatible endpoint
if not settings.OPENAI_BASE_URL:
self._add_docsgpt_models(settings)
if settings.OPENAI_API_KEY or (
settings.LLM_PROVIDER == "openai" and settings.API_KEY
):
@@ -125,19 +127,34 @@ class ModelRegistry:
)
def _add_openai_models(self, settings):
from application.core.model_configs import OPENAI_MODELS
from application.core.model_configs import (
OPENAI_MODELS,
create_custom_openai_model,
)
# Add standard OpenAI models if API key is present
if settings.OPENAI_API_KEY:
for model in OPENAI_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "openai" and settings.LLM_NAME:
# Add custom model if OPENAI_BASE_URL is configured with a custom LLM_NAME
if (
settings.LLM_PROVIDER == "openai"
and settings.OPENAI_BASE_URL
and settings.LLM_NAME
):
custom_model = create_custom_openai_model(
settings.LLM_NAME, settings.OPENAI_BASE_URL
)
self.models[settings.LLM_NAME] = custom_model
logger.info(
f"Registered custom OpenAI model: {settings.LLM_NAME} at {settings.OPENAI_BASE_URL}"
)
# Fallback: add all OpenAI models if none were added
if not any(m.provider.value == "openai" for m in self.models.values()):
for model in OPENAI_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in OPENAI_MODELS:
self.models[model.id] = model
self.models[model.id] = model
def _add_azure_openai_models(self, settings):
from application.core.model_configs import AZURE_OPENAI_MODELS

View File

@@ -136,7 +136,7 @@ function PromptTextarea({
</div>
<textarea
id={id}
className="peer border-silver dark:border-silver/40 relative z-10 h-48 w-full resize-none rounded border-2 bg-transparent px-3 py-2 text-base text-gray-800 outline-none dark:bg-transparent dark:text-white"
className="peer border-silver dark:border-silver/40 relative z-10 h-48 w-full resize-none rounded border-2 bg-transparent px-3 py-2 text-base text-gray-800 outline-none md:h-64 lg:h-80 dark:bg-transparent dark:text-white"
value={value}
onChange={onChange}
onScroll={handleScroll}
@@ -765,7 +765,7 @@ export default function PromptsModal({
setNewPromptContent('');
}
}}
className="mx-4 mt-16 w-[95vw] max-w-[650px] rounded-2xl bg-white px-4 py-4 sm:px-6 sm:py-6 md:px-8 md:py-6 dark:bg-[#1E1E2A]"
className="mx-4 mt-16 w-[95vw] max-w-[650px] rounded-2xl bg-white px-4 py-4 sm:px-6 sm:py-6 md:max-w-[860px] md:px-8 md:py-6 lg:max-w-[980px] dark:bg-[#1E1E2A]"
contentClassName="!overflow-visible"
>
{view}

View File

@@ -64,10 +64,12 @@ def test_model_without_base_url():
def test_validate_model_id():
"""Test model_id validation"""
# Get the registry instance to check what models are available
ModelRegistry.get_instance()
registry = ModelRegistry.get_instance()
# Test with a model that should exist (docsgpt-local is always added)
assert validate_model_id("docsgpt-local") is True
# Test with a model that exists in the registry
available_models = registry.get_all_models()
if available_models:
assert validate_model_id(available_models[0].id) is True
# Test with invalid model_id
assert validate_model_id("invalid-model-xyz-123") is False
@@ -79,14 +81,19 @@ def test_validate_model_id():
@pytest.mark.unit
def test_get_base_url_for_model():
"""Test retrieving base_url for a model"""
# Test with a model that doesn't have base_url
result = get_base_url_for_model("docsgpt-local")
assert result is None # docsgpt-local doesn't have custom base_url
# Test with invalid model
result = get_base_url_for_model("invalid-model")
assert result is None
# Test with a model that exists but may or may not have base_url
registry = ModelRegistry.get_instance()
available_models = registry.get_all_models()
if available_models:
model = available_models[0]
result = get_base_url_for_model(model.id)
# Result should match the model's base_url (could be None or a string)
assert result == model.base_url
@pytest.mark.unit
def test_model_validation_error_message():