feat: model registry and capabilities for multi-provider support (#2158)

* feat: Implement model registry and capabilities for multi-provider support

- Added ModelRegistry to manage available models and their capabilities.
- Introduced ModelProvider enum for different LLM providers.
- Created ModelCapabilities dataclass to define model features.
- Implemented methods to load models based on API keys and settings.
- Added utility functions for model management in model_utils.py.
- Updated settings.py to include provider-specific API keys.
- Refactored LLM classes (Anthropic, OpenAI, Google, etc.) to utilize new model registry.
- Enhanced utility functions to handle token limits and model validation.
- Improved code structure and logging for better maintainability.

* feat: Add model selection feature with API integration and UI component

* feat: Add model selection and default model functionality in agent management

* test: Update assertions and formatting in stream processing tests

* refactor(llm): Standardize model identifier to model_id

* fix tests

---------

Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
Siddhant Rai
2025-11-14 16:43:19 +05:30
committed by GitHub
parent fbf7cf874b
commit 3f7de867cc
54 changed files with 1388 additions and 226 deletions

View File

@@ -1,5 +1,8 @@
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
import logging
logger = logging.getLogger(__name__)
class AgentCreator:
@@ -13,4 +16,5 @@ class AgentCreator:
agent_class = cls.agents.get(type.lower())
if not agent_class:
raise ValueError(f"No agent class found for type {type}")
return agent_class(*args, **kwargs)

View File

@@ -21,7 +21,7 @@ class BaseAgent(ABC):
self,
endpoint: str,
llm_name: str,
gpt_model: str,
model_id: str,
api_key: str,
user_api_key: Optional[str] = None,
prompt: str = "",
@@ -37,7 +37,7 @@ class BaseAgent(ABC):
):
self.endpoint = endpoint
self.llm_name = llm_name
self.gpt_model = gpt_model
self.model_id = model_id
self.api_key = api_key
self.user_api_key = user_api_key
self.prompt = prompt
@@ -52,6 +52,7 @@ class BaseAgent(ABC):
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
)
self.retrieved_docs = retrieved_docs or []
self.llm_handler = LLMHandlerCreator.create_handler(
@@ -316,7 +317,7 @@ class BaseAgent(ABC):
return messages
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
gen_kwargs = {"model": self.gpt_model, "messages": messages}
gen_kwargs = {"model": self.model_id, "messages": messages}
if (
hasattr(self.llm, "_supports_tools")

View File

@@ -86,7 +86,7 @@ class ReActAgent(BaseAgent):
messages = [{"role": "user", "content": plan_prompt}]
plan_stream = self.llm.gen_stream(
model=self.gpt_model,
model=self.model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
@@ -151,7 +151,7 @@ class ReActAgent(BaseAgent):
messages = [{"role": "user", "content": final_prompt}]
final_stream = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=None
model=self.model_id, messages=messages, tools=None
)
if log_context:

View File

@@ -54,6 +54,10 @@ class AnswerResource(Resource, BaseAnswerResource):
default=True,
description="Whether to save the conversation",
),
"model_id": fields.String(
required=False,
description="Model ID to use for this request",
),
"passthrough": fields.Raw(
required=False,
description="Dynamic parameters to inject into prompt template",
@@ -97,6 +101,7 @@ class AnswerResource(Resource, BaseAnswerResource):
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)

View File

@@ -7,11 +7,16 @@ from flask import jsonify, make_response, Response
from flask_restx import Namespace
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
get_default_model_id,
get_provider_from_model_id,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import check_required_fields, get_gpt_model
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
@@ -27,7 +32,7 @@ class BaseAnswerResource:
db = mongo[settings.MONGO_DB_NAME]
self.db = db
self.user_logs_collection = db["user_logs"]
self.gpt_model = get_gpt_model()
self.default_model_id = get_default_model_id()
self.conversation_service = ConversationService()
def validate_request(
@@ -54,7 +59,6 @@ class BaseAnswerResource:
api_key = agent_config.get("user_api_key")
if not api_key:
return None
agents_collection = self.db["agents"]
agent = agents_collection.find_one({"key": api_key})
@@ -62,7 +66,6 @@ class BaseAnswerResource:
return make_response(
jsonify({"success": False, "message": "Invalid API key."}), 401
)
limited_token_mode_raw = agent.get("limited_token_mode", False)
limited_request_mode_raw = agent.get("limited_request_mode", False)
@@ -110,15 +113,12 @@ class BaseAnswerResource:
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
else:
daily_token_usage = 0
if limited_request_mode:
daily_request_usage = token_usage_collection.count_documents(match_query)
else:
daily_request_usage = 0
if not limited_token_mode and not limited_request_mode:
return None
token_exceeded = (
limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
)
@@ -138,7 +138,6 @@ class BaseAnswerResource:
),
429,
)
return None
def complete_stream(
@@ -155,6 +154,7 @@ class BaseAnswerResource:
agent_id: Optional[str] = None,
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
) -> Generator[str, None, None]:
"""
Generator function that streams the complete conversation response.
@@ -173,6 +173,7 @@ class BaseAnswerResource:
agent_id: ID of agent used
is_shared_usage: Flag for shared agent usage
shared_token: Token for shared agent
model_id: Model ID used for the request
retrieved_docs: Pre-fetched documents for sources (optional)
Yields:
@@ -220,7 +221,6 @@ class BaseAnswerResource:
elif "type" in line:
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
"type": "structured_answer",
@@ -230,15 +230,22 @@ class BaseAnswerResource:
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
provider = (
get_provider_from_model_id(model_id)
if model_id
else settings.LLM_PROVIDER
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
provider or settings.LLM_PROVIDER,
api_key=system_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
)
if should_save_conversation:
@@ -250,7 +257,7 @@ class BaseAnswerResource:
source_log_docs,
tool_calls,
llm,
self.gpt_model,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
@@ -280,12 +287,11 @@ class BaseAnswerResource:
log_data["structured_output"] = True
if schema_info:
log_data["schema"] = schema_info
# Clean up text fields to be no longer than 10000 characters
for key, value in log_data.items():
if isinstance(value, str) and len(value) > 10000:
log_data[key] = value[:10000]
self.user_logs_collection.insert_one(log_data)
data = json.dumps({"type": "end"})
@@ -293,6 +299,7 @@ class BaseAnswerResource:
except GeneratorExit:
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
# Save partial response
if should_save_conversation and response_full:
try:
if isNoneDoc:
@@ -312,7 +319,7 @@ class BaseAnswerResource:
source_log_docs,
tool_calls,
llm,
self.gpt_model,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
@@ -369,7 +376,7 @@ class BaseAnswerResource:
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return None, None, None, None, event["error"]
return None, None, None, None, event["error"], None
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
@@ -377,8 +384,7 @@ class BaseAnswerResource:
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return None, None, None, None, "Stream ended unexpectedly"
return None, None, None, None, "Stream ended unexpectedly", None
result = (
conversation_id,
response_full,
@@ -390,7 +396,6 @@ class BaseAnswerResource:
if is_structured:
result = result + ({"structured": True, "schema": schema_info},)
return result
def error_stream_generate(self, err_response):

View File

@@ -57,6 +57,10 @@ class StreamResource(Resource, BaseAnswerResource):
default=True,
description="Whether to save the conversation",
),
"model_id": fields.String(
required=False,
description="Model ID to use for this request",
),
"attachments": fields.List(
fields.String, required=False, description="List of attachment IDs"
),
@@ -101,6 +105,7 @@ class StreamResource(Resource, BaseAnswerResource):
agent_id=data.get("agent_id"),
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
),
mimetype="text/event-stream",
)

View File

@@ -52,7 +52,7 @@ class ConversationService:
sources: List[Dict[str, Any]],
tool_calls: List[Dict[str, Any]],
llm: Any,
gpt_model: str,
model_id: str,
decoded_token: Dict[str, Any],
index: Optional[int] = None,
api_key: Optional[str] = None,
@@ -66,7 +66,7 @@ class ConversationService:
if not user_id:
raise ValueError("User ID not found in token")
current_time = datetime.now(timezone.utc)
# clean up in sources array such that we save max 1k characters for text part
for source in sources:
if "text" in source and isinstance(source["text"], str):
@@ -90,6 +90,7 @@ class ConversationService:
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
f"queries.{index}.model_id": model_id,
}
},
)
@@ -120,6 +121,7 @@ class ConversationService:
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
}
},
@@ -146,7 +148,7 @@ class ConversationService:
]
completion = llm.gen(
model=gpt_model, messages=messages_summary, max_tokens=30
model=model_id, messages=messages_summary, max_tokens=30
)
conversation_data = {
@@ -162,6 +164,7 @@ class ConversationService:
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
],
}

View File

@@ -12,12 +12,17 @@ from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.conversation_service import ConversationService
from application.api.answer.services.prompt_renderer import PromptRenderer
from application.core.model_utils import (
get_api_key_for_provider,
get_default_model_id,
get_provider_from_model_id,
validate_model_id,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import (
calculate_doc_token_budget,
get_gpt_model,
limit_chat_history,
)
@@ -83,7 +88,7 @@ class StreamProcessor:
self.retriever_config = {}
self.is_shared_usage = False
self.shared_token = None
self.gpt_model = get_gpt_model()
self.model_id: Optional[str] = None
self.conversation_service = ConversationService()
self.prompt_renderer = PromptRenderer()
self._prompt_content: Optional[str] = None
@@ -91,6 +96,7 @@ class StreamProcessor:
def initialize(self):
"""Initialize all required components for processing"""
self._validate_and_set_model()
self._configure_agent()
self._configure_source()
self._configure_retriever()
@@ -112,7 +118,7 @@ class StreamProcessor:
]
else:
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
json.loads(self.data.get("history", "[]")), model_id=self.model_id
)
def _process_attachments(self):
@@ -143,6 +149,25 @@ class StreamProcessor:
)
return attachments
def _validate_and_set_model(self):
    """Resolve ``self.model_id`` from the request payload.

    Uses the ``model_id`` field of ``self.data`` when it is present and known
    to the model registry; otherwise falls back to the system default model.

    Raises:
        ValueError: if the request names a model_id the registry does not
            contain. The message lists up to five enabled model ids.
    """
    from application.core.model_settings import ModelRegistry

    requested_model = self.data.get("model_id")

    # No explicit choice in the request: use the system default.
    if not requested_model:
        self.model_id = get_default_model_id()
        return

    if validate_model_id(requested_model):
        self.model_id = requested_model
        return

    # Unknown id: build a short hint of what is available for the error text.
    enabled_ids = [
        entry.id for entry in ModelRegistry.get_instance().get_enabled_models()
    ]
    if len(enabled_ids) > 5:
        overflow = f" and {len(enabled_ids) - 5} more"
    else:
        overflow = ""
    raise ValueError(
        f"Invalid model_id '{requested_model}'. "
        f"Available models: {', '.join(enabled_ids[:5])}"
        + overflow
    )
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control"""
if not agent_id:
@@ -322,7 +347,7 @@ class StreamProcessor:
def _configure_retriever(self):
history_token_limit = int(self.data.get("token_limit", 2000))
doc_token_limit = calculate_doc_token_budget(
gpt_model=self.gpt_model, history_token_limit=history_token_limit
model_id=self.model_id, history_token_limit=history_token_limit
)
self.retriever_config = {
@@ -344,7 +369,7 @@ class StreamProcessor:
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
chunks=self.retriever_config["chunks"],
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
gpt_model=self.gpt_model,
model_id=self.model_id,
user_api_key=self.agent_config["user_api_key"],
decoded_token=self.decoded_token,
)
@@ -626,12 +651,19 @@ class StreamProcessor:
tools_data=tools_data,
)
provider = (
get_provider_from_model_id(self.model_id)
if self.model_id
else settings.LLM_PROVIDER
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
return AgentCreator.create_agent(
self.agent_config["agent_type"],
endpoint="stream",
llm_name=settings.LLM_PROVIDER,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
llm_name=provider or settings.LLM_PROVIDER,
model_id=self.model_id,
api_key=system_api_key,
user_api_key=self.agent_config["user_api_key"],
prompt=rendered_prompt,
chat_history=self.history,

View File

@@ -95,6 +95,8 @@ class GetAgent(Resource):
"shared": agent.get("shared_publicly", False),
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
}
return make_response(jsonify(data), 200)
except Exception as e:
@@ -172,6 +174,8 @@ class GetAgents(Resource):
"shared": agent.get("shared_publicly", False),
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
}
for agent in agents
if "source" in agent or "retriever" in agent
@@ -230,6 +234,14 @@ class CreateAgent(Resource):
required=False,
description="Request limit for the agent in limited mode",
),
"models": fields.List(
fields.String,
required=False,
description="List of available model IDs for this agent",
),
"default_model_id": fields.String(
required=False, description="Default model ID for this agent"
),
},
)
@@ -258,6 +270,11 @@ class CreateAgent(Resource):
data["json_schema"] = json.loads(data["json_schema"])
except json.JSONDecodeError:
data["json_schema"] = None
if "models" in data:
try:
data["models"] = json.loads(data["models"])
except json.JSONDecodeError:
data["models"] = []
print(f"Received data: {data}")
# Validate JSON schema if provided
@@ -399,6 +416,8 @@ class CreateAgent(Resource):
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
"lastUsedAt": None,
"key": key,
"models": data.get("models", []),
"default_model_id": data.get("default_model_id", ""),
}
if new_agent["chunks"] == "":
new_agent["chunks"] = "2"
@@ -464,6 +483,14 @@ class UpdateAgent(Resource):
required=False,
description="Request limit for the agent in limited mode",
),
"models": fields.List(
fields.String,
required=False,
description="List of available model IDs for this agent",
),
"default_model_id": fields.String(
required=False, description="Default model ID for this agent"
),
},
)
@@ -487,7 +514,7 @@ class UpdateAgent(Resource):
data = request.get_json()
else:
data = request.form.to_dict()
json_fields = ["tools", "sources", "json_schema"]
json_fields = ["tools", "sources", "json_schema", "models"]
for field in json_fields:
if field in data and data[field]:
try:
@@ -555,6 +582,8 @@ class UpdateAgent(Resource):
"token_limit",
"limited_request_mode",
"request_limit",
"models",
"default_model_id",
]
for field in allowed_fields:

View File

@@ -0,0 +1,3 @@
from .routes import models_ns
__all__ = ["models_ns"]

View File

@@ -0,0 +1,25 @@
from flask import current_app, jsonify, make_response
from flask_restx import Namespace, Resource
from application.core.model_settings import ModelRegistry
models_ns = Namespace("models", description="Available models", path="/api")
@models_ns.route("/models")
class ModelsListResource(Resource):
    # Read-only endpoint: serialises the enabled entries of the ModelRegistry.

    def get(self):
        """Get list of available models with their capabilities."""
        try:
            registry = ModelRegistry.get_instance()
            enabled = registry.get_enabled_models()
            serialized = [entry.to_dict() for entry in enabled]
            payload = {
                "models": serialized,
                "default_model_id": registry.default_model_id,
                "count": len(enabled),
            }
        except Exception as err:
            # Any registry failure is logged with traceback and reported as a
            # bare 500 so no internal detail reaches the client.
            current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 500)
        else:
            return make_response(jsonify(payload), 200)

View File

@@ -10,6 +10,7 @@ from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns
from .analytics import analytics_ns
from .attachments import attachments_ns
from .conversations import conversations_ns
from .models import models_ns
from .prompts import prompts_ns
from .sharing import sharing_ns
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
@@ -27,6 +28,9 @@ api.add_namespace(attachments_ns)
# Conversations
api.add_namespace(conversations_ns)
# Models
api.add_namespace(models_ns)
# Agents (main, sharing, webhooks)
api.add_namespace(agents_ns)
api.add_namespace(agents_sharing_ns)

View File

@@ -0,0 +1,223 @@
"""
Model configurations for all supported LLM providers.
"""
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
)
# MIME types listed as supported attachments for the OpenAI and Azure OpenAI
# model entries defined in this module.
OPENAI_ATTACHMENTS = [
    "application/pdf",
    "image/png",
    "image/jpeg",
    "image/jpg",
    "image/webp",
    "image/gif",
]
# MIME types listed as supported attachments for the Google model entries.
# Currently the same values as OPENAI_ATTACHMENTS, kept as a separate list.
GOOGLE_ATTACHMENTS = [
    "application/pdf",
    "image/png",
    "image/jpeg",
    "image/jpg",
    "image/webp",
    "image/gif",
]
# Catalogue of OpenAI models offered by the registry.
# NOTE(review): context_window values are hard-coded here; confirm them against
# the provider's current documentation when updating this list.
OPENAI_MODELS = [
    AvailableModel(
        id="gpt-4o",
        provider=ModelProvider.OPENAI,
        display_name="GPT-4 Omni",
        description="Latest and most capable model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=OPENAI_ATTACHMENTS,
            context_window=128000,
        ),
    ),
    AvailableModel(
        id="gpt-4o-mini",
        provider=ModelProvider.OPENAI,
        display_name="GPT-4 Omni Mini",
        description="Fast and efficient",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=OPENAI_ATTACHMENTS,
            context_window=128000,
        ),
    ),
    AvailableModel(
        id="gpt-4-turbo",
        provider=ModelProvider.OPENAI,
        display_name="GPT-4 Turbo",
        description="Fast GPT-4 with 128k context",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=OPENAI_ATTACHMENTS,
            context_window=128000,
        ),
    ),
    AvailableModel(
        id="gpt-4",
        provider=ModelProvider.OPENAI,
        display_name="GPT-4",
        description="Most capable model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=OPENAI_ATTACHMENTS,
            context_window=8192,
        ),
    ),
    # This entry relies on the ModelCapabilities defaults for structured output
    # and attachments (i.e. neither is declared as supported).
    AvailableModel(
        id="gpt-3.5-turbo",
        provider=ModelProvider.OPENAI,
        display_name="GPT-3.5 Turbo",
        description="Fast and cost-effective",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=4096,
        ),
    ),
]
# Catalogue of Anthropic models offered by the registry. All entries declare
# tool support and a 200000-token context window; structured output and
# attachments are left at the ModelCapabilities defaults.
# NOTE(review): several ids carry no date suffix; confirm they are accepted
# as-is by the provider API.
ANTHROPIC_MODELS = [
    AvailableModel(
        id="claude-3-5-sonnet-20241022",
        provider=ModelProvider.ANTHROPIC,
        display_name="Claude 3.5 Sonnet (Latest)",
        description="Latest Claude 3.5 Sonnet with enhanced capabilities",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=200000,
        ),
    ),
    AvailableModel(
        id="claude-3-5-sonnet",
        provider=ModelProvider.ANTHROPIC,
        display_name="Claude 3.5 Sonnet",
        description="Balanced performance and capability",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=200000,
        ),
    ),
    AvailableModel(
        id="claude-3-opus",
        provider=ModelProvider.ANTHROPIC,
        display_name="Claude 3 Opus",
        description="Most capable Claude model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=200000,
        ),
    ),
    AvailableModel(
        id="claude-3-haiku",
        provider=ModelProvider.ANTHROPIC,
        display_name="Claude 3 Haiku",
        description="Fastest Claude model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=200000,
        ),
    ),
]
# Catalogue of Google Gemini models offered by the registry. Every entry
# declares tools, structured output and the GOOGLE_ATTACHMENTS MIME types.
# NOTE(review): context_window values are hard-coded; confirm against the
# provider's current documentation.
GOOGLE_MODELS = [
    AvailableModel(
        id="gemini-flash-latest",
        provider=ModelProvider.GOOGLE,
        display_name="Gemini Flash (Latest)",
        description="Latest experimental Gemini model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=GOOGLE_ATTACHMENTS,
            context_window=int(1e6),
        ),
    ),
    AvailableModel(
        id="gemini-flash-lite-latest",
        provider=ModelProvider.GOOGLE,
        display_name="Gemini Flash Lite (Latest)",
        description="Fast with huge context window",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=GOOGLE_ATTACHMENTS,
            context_window=int(1e6),
        ),
    ),
    AvailableModel(
        id="gemini-2.5-pro",
        provider=ModelProvider.GOOGLE,
        display_name="Gemini 2.5 Pro",
        description="Most capable Gemini model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=GOOGLE_ATTACHMENTS,
            context_window=2000000,
        ),
    ),
]
# Catalogue of Groq-hosted models offered by the registry. All entries declare
# tool support; structured output and attachments are left at the
# ModelCapabilities defaults.
GROQ_MODELS = [
    AvailableModel(
        id="llama-3.3-70b-versatile",
        provider=ModelProvider.GROQ,
        display_name="Llama 3.3 70B",
        description="Latest Llama model with high-speed inference",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=128000,
        ),
    ),
    AvailableModel(
        id="llama-3.1-8b-instant",
        provider=ModelProvider.GROQ,
        display_name="Llama 3.1 8B",
        description="Ultra-fast inference",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=128000,
        ),
    ),
    AvailableModel(
        id="mixtral-8x7b-32768",
        provider=ModelProvider.GROQ,
        display_name="Mixtral 8x7B",
        description="High-speed inference with tools",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=32768,
        ),
    ),
]
# Catalogue of Azure OpenAI models offered by the registry (a single entry).
# It reuses the OPENAI_ATTACHMENTS MIME types.
AZURE_OPENAI_MODELS = [
    AvailableModel(
        id="azure-gpt-4",
        provider=ModelProvider.AZURE_OPENAI,
        display_name="Azure OpenAI GPT-4",
        description="Azure-hosted GPT model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=OPENAI_ATTACHMENTS,
            context_window=8192,
        ),
    ),
]

View File

@@ -0,0 +1,236 @@
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
class ModelProvider(str, Enum):
    """Identifiers of the LLM backends a model entry can belong to.

    Mixes in ``str``, so each member compares equal to its plain string value
    (e.g. ``ModelProvider.OPENAI == "openai"``).
    """

    OPENAI = "openai"
    AZURE_OPENAI = "azure_openai"
    ANTHROPIC = "anthropic"
    GROQ = "groq"
    GOOGLE = "google"
    HUGGINGFACE = "huggingface"
    LLAMA_CPP = "llama.cpp"
    DOCSGPT = "docsgpt"
    PREMAI = "premai"
    SAGEMAKER = "sagemaker"
    NOVITA = "novita"
@dataclass
class ModelCapabilities:
    """Feature flags and limits declared for a single model entry.

    Every field has a default, so ``ModelCapabilities()`` describes a
    streaming-only model with no tools, no structured output, no attachments
    and a 128000-token context window.
    """

    # Declared support for tool calls.
    supports_tools: bool = False
    # Declared support for structured output.
    supports_structured_output: bool = False
    # Declared support for streamed responses.
    supports_streaming: bool = True
    # MIME types accepted as attachments; a fresh empty list per instance.
    supported_attachment_types: List[str] = field(default_factory=list)
    # Context size in tokens.
    context_window: int = 128000
    # Optional pricing metadata; unset by default.
    input_cost_per_token: Optional[float] = None
    output_cost_per_token: Optional[float] = None
@dataclass
class AvailableModel:
    """One selectable model: identity, provider, capabilities and status."""

    id: str
    provider: ModelProvider
    display_name: str
    description: str = ""
    capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
    enabled: bool = True
    base_url: Optional[str] = None

    def to_dict(self) -> Dict:
        """Return a flat, JSON-friendly view of this model.

        Capability fields are hoisted to the top level. ``base_url`` is
        included only when it is set to a non-empty value.
        """
        caps = self.capabilities
        payload = {
            "id": self.id,
            "provider": self.provider.value,
            "display_name": self.display_name,
            "description": self.description,
            "supported_attachment_types": caps.supported_attachment_types,
            "supports_tools": caps.supports_tools,
            "supports_structured_output": caps.supports_structured_output,
            "supports_streaming": caps.supports_streaming,
            "context_window": caps.context_window,
            "enabled": self.enabled,
        }
        if not self.base_url:
            return payload
        payload["base_url"] = self.base_url
        return payload
class ModelRegistry:
    """Process-wide singleton catalogue of the models this deployment can use.

    The catalogue is built once, on first instantiation, from the application
    settings. The built-in DocsGPT model is always registered; each other
    provider's models are added when that provider's settings are present
    (see ``_load_models``). ``default_model_id`` always names a registered
    model once loading has finished.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        # Always hand back the same object so every caller shares one catalogue.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ``ModelRegistry()`` call; populate only once.
        if not ModelRegistry._initialized:
            self.models: Dict[str, "AvailableModel"] = {}
            self.default_model_id: Optional[str] = None
            self._load_models()
            ModelRegistry._initialized = True

    @classmethod
    def get_instance(cls) -> "ModelRegistry":
        """Return the shared registry, creating and loading it on first use."""
        return cls()

    def _load_models(self):
        """(Re)build the catalogue from settings and pick the default model."""
        # Imported lazily so this module can be imported before settings exist.
        from application.core.settings import settings

        self.models.clear()
        self._add_docsgpt_models(settings)
        if settings.OPENAI_API_KEY or (
            settings.LLM_PROVIDER == "openai" and settings.API_KEY
        ):
            self._add_openai_models(settings)
        if settings.OPENAI_API_BASE or (
            settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
        ):
            self._add_azure_openai_models(settings)
        if settings.ANTHROPIC_API_KEY or (
            settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
        ):
            self._add_anthropic_models(settings)
        if settings.GOOGLE_API_KEY or (
            settings.LLM_PROVIDER == "google" and settings.API_KEY
        ):
            self._add_google_models(settings)
        if settings.GROQ_API_KEY or (
            settings.LLM_PROVIDER == "groq" and settings.API_KEY
        ):
            self._add_groq_models(settings)
        if settings.HUGGINGFACE_API_KEY or (
            settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
        ):
            self._add_huggingface_models(settings)
        self._select_default_model(settings)
        logger.info(
            f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
        )

    def _select_default_model(self, settings):
        """Set ``default_model_id`` to a registered model.

        Preference order:
        1. ``settings.LLM_NAME`` when it names a registered model;
        2. the first registered model of ``settings.LLM_PROVIDER`` when both
           ``LLM_PROVIDER`` and ``API_KEY`` are set;
        3. the first registered model.

        Step 3 also covers the case where step 2 finds no model for the
        provider, so the default is never left as ``None`` while the
        catalogue is non-empty.
        """
        chosen = None
        if settings.LLM_NAME and settings.LLM_NAME in self.models:
            chosen = settings.LLM_NAME
        elif settings.LLM_PROVIDER and settings.API_KEY:
            chosen = next(
                (
                    model_id
                    for model_id, model in self.models.items()
                    if model.provider.value == settings.LLM_PROVIDER
                ),
                None,
            )
        if chosen is None:
            chosen = next(iter(self.models), None)
        self.default_model_id = chosen

    def _register_catalog(self, settings, catalog, provider_name, provider_key=None):
        """Register entries of ``catalog`` for one provider.

        - With a provider-specific key, the whole catalogue is registered.
        - Otherwise, when ``provider_name`` is the configured ``LLM_PROVIDER``
          and ``LLM_NAME`` matches a catalogue entry, only that entry is
          registered.
        - In every other case (including an ``LLM_NAME`` that matches
          nothing) the whole catalogue is registered.
        """
        pinned_by_name = (
            not provider_key
            and settings.LLM_PROVIDER == provider_name
            and settings.LLM_NAME
        )
        if pinned_by_name:
            for model in catalog:
                if model.id == settings.LLM_NAME:
                    self.models[model.id] = model
                    return
        for model in catalog:
            self.models[model.id] = model

    def _add_openai_models(self, settings):
        """Register OpenAI models (see ``_register_catalog`` for the rules)."""
        from application.core.model_configs import OPENAI_MODELS

        self._register_catalog(
            settings, OPENAI_MODELS, "openai", settings.OPENAI_API_KEY
        )

    def _add_azure_openai_models(self, settings):
        """Register Azure OpenAI models; there is no Azure-specific key."""
        from application.core.model_configs import AZURE_OPENAI_MODELS

        self._register_catalog(settings, AZURE_OPENAI_MODELS, "azure_openai")

    def _add_anthropic_models(self, settings):
        """Register Anthropic models (see ``_register_catalog`` for the rules)."""
        from application.core.model_configs import ANTHROPIC_MODELS

        self._register_catalog(
            settings, ANTHROPIC_MODELS, "anthropic", settings.ANTHROPIC_API_KEY
        )

    def _add_google_models(self, settings):
        """Register Google models (see ``_register_catalog`` for the rules)."""
        from application.core.model_configs import GOOGLE_MODELS

        self._register_catalog(
            settings, GOOGLE_MODELS, "google", settings.GOOGLE_API_KEY
        )

    def _add_groq_models(self, settings):
        """Register Groq models (see ``_register_catalog`` for the rules)."""
        from application.core.model_configs import GROQ_MODELS

        self._register_catalog(settings, GROQ_MODELS, "groq", settings.GROQ_API_KEY)

    def _add_docsgpt_models(self, settings):
        """Register the built-in DocsGPT model; called unconditionally."""
        model_id = "docsgpt-local"
        model = AvailableModel(
            id=model_id,
            provider=ModelProvider.DOCSGPT,
            display_name="DocsGPT Model",
            description="Local model",
            capabilities=ModelCapabilities(
                supports_tools=False,
                supported_attachment_types=[],
            ),
        )
        self.models[model_id] = model

    def _add_huggingface_models(self, settings):
        """Register the single placeholder entry for a Hugging Face model."""
        model_id = "huggingface-local"
        model = AvailableModel(
            id=model_id,
            provider=ModelProvider.HUGGINGFACE,
            display_name="Hugging Face Model",
            description="Local Hugging Face model",
            capabilities=ModelCapabilities(
                supports_tools=False,
                supported_attachment_types=[],
            ),
        )
        self.models[model_id] = model

    def get_model(self, model_id: str) -> Optional["AvailableModel"]:
        """Return the model registered under ``model_id``, or ``None``."""
        return self.models.get(model_id)

    def get_all_models(self) -> List["AvailableModel"]:
        """Return every registered model, enabled or not."""
        return list(self.models.values())

    def get_enabled_models(self) -> List["AvailableModel"]:
        """Return only the registered models whose ``enabled`` flag is true."""
        return [m for m in self.models.values() if m.enabled]

    def model_exists(self, model_id: str) -> bool:
        """Return whether ``model_id`` is registered."""
        return model_id in self.models

View File

@@ -0,0 +1,91 @@
from typing import Any, Dict, Optional
from application.core.model_settings import ModelRegistry
def get_api_key_for_provider(provider: str) -> Optional[str]:
    """Return the API key to use for ``provider``.

    A provider-specific key from settings wins when it is set; in every other
    case (no dedicated key, unknown provider, or a provider mapped to ``None``)
    the generic ``settings.API_KEY`` is returned.
    """
    from application.core.settings import settings

    dedicated_keys = {
        "openai": settings.OPENAI_API_KEY,
        "anthropic": settings.ANTHROPIC_API_KEY,
        "google": settings.GOOGLE_API_KEY,
        "groq": settings.GROQ_API_KEY,
        "huggingface": settings.HUGGINGFACE_API_KEY,
        "azure_openai": settings.API_KEY,
        "docsgpt": None,
        "llama.cpp": None,
    }
    # An empty or missing dedicated key falls through to the generic key.
    return dedicated_keys.get(provider) or settings.API_KEY
def get_all_available_models() -> Dict[str, Dict[str, Any]]:
    """Return every enabled model as ``{model_id: model.to_dict()}``."""
    catalogue: Dict[str, Dict[str, Any]] = {}
    for entry in ModelRegistry.get_instance().get_enabled_models():
        catalogue[entry.id] = entry.to_dict()
    return catalogue
def validate_model_id(model_id: str) -> bool:
    """Return whether ``model_id`` is registered in the model registry."""
    return ModelRegistry.get_instance().model_exists(model_id)
def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
    """Return a capability summary for ``model_id``.

    Returns ``None`` when the model is not registered. The summary contains
    attachment types, tool support, structured-output support and the context
    window.
    """
    entry = ModelRegistry.get_instance().get_model(model_id)
    if not entry:
        return None
    caps = entry.capabilities
    return {
        "supported_attachment_types": caps.supported_attachment_types,
        "supports_tools": caps.supports_tools,
        "supports_structured_output": caps.supports_structured_output,
        "context_window": caps.context_window,
    }
def get_default_model_id() -> str:
    """Return the registry's ``default_model_id`` (the system default model)."""
    return ModelRegistry.get_instance().default_model_id
def get_provider_from_model_id(model_id: str) -> Optional[str]:
    """Return the provider name (string value) for ``model_id``.

    Returns ``None`` when the model is not registered.
    """
    entry = ModelRegistry.get_instance().get_model(model_id)
    return entry.provider.value if entry else None
def get_token_limit(model_id: str) -> int:
    """Return the context window (token limit) for ``model_id``.

    When the model is not registered, falls back to
    ``settings.DEFAULT_LLM_TOKEN_LIMIT``.
    """
    from application.core.settings import settings

    entry = ModelRegistry.get_instance().get_model(model_id)
    if not entry:
        return settings.DEFAULT_LLM_TOKEN_LIMIT
    return entry.capabilities.context_window
def get_base_url_for_model(model_id: str) -> Optional[str]:
    """Return the ``base_url`` configured on a registered model.

    The result is ``None`` when the registry yields no model for
    ``model_id``; otherwise it is the model's ``base_url`` attribute
    unchanged.
    """
    entry = ModelRegistry.get_instance().get_model(model_id)
    return entry.base_url if entry else None

View File

@@ -22,15 +22,7 @@ class Settings(BaseSettings):
MONGO_DB_NAME: str = "docsgpt"
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
LLM_TOKEN_LIMITS: dict = {
"gpt-4o": 128000,
"gpt-4o-mini": 128000,
"gpt-4": 8192,
"gpt-3.5-turbo": 4096,
"claude-2": int(1e5),
"gemini-2.5-flash": int(1e6),
}
DEFAULT_LLM_TOKEN_LIMIT: int = 128000
DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry
RESERVED_TOKENS: dict = {
"system_prompt": 500,
"current_query": 500,
@@ -64,14 +56,22 @@ class Settings(BaseSettings):
)
# GitHub source
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
# LLM Cache
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
API_URL: str = "http://localhost:7091" # backend url for celery worker
API_KEY: Optional[str] = None # LLM api key
API_KEY: Optional[str] = None # LLM api key (used by LLM_PROVIDER)
# Provider-specific API keys (for multi-model support)
OPENAI_API_KEY: Optional[str] = None
ANTHROPIC_API_KEY: Optional[str] = None
GOOGLE_API_KEY: Optional[str] = None
GROQ_API_KEY: Optional[str] = None
HUGGINGFACE_API_KEY: Optional[str] = None
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
@@ -138,11 +138,12 @@ class Settings(BaseSettings):
# Encryption settings
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
ELEVENLABS_API_KEY: Optional[str] = None
# Tool pre-fetch settings
ENABLE_TOOL_PREFETCH: bool = True
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")

View File

@@ -1,30 +1,41 @@
from application.llm.base import BaseLLM
from anthropic import AI_PROMPT, Anthropic, HUMAN_PROMPT
from application.core.settings import settings
from application.llm.base import BaseLLM
class AnthropicLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = (
api_key or settings.ANTHROPIC_API_KEY
) # If not provided, use a default from settings
self.api_key = api_key or settings.ANTHROPIC_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
self.anthropic = Anthropic(api_key=self.api_key)
# Use custom base_url if provided
if base_url:
self.anthropic = Anthropic(api_key=self.api_key, base_url=base_url)
else:
self.anthropic = Anthropic(api_key=self.api_key)
self.HUMAN_PROMPT = HUMAN_PROMPT
self.AI_PROMPT = AI_PROMPT
def _raw_gen(
self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs
self,
baseself,
model,
messages,
stream=False,
tools=None,
max_tokens=300,
**kwargs,
):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
if stream:
return self.gen_stream(model, prompt, stream, max_tokens, **kwargs)
completion = self.anthropic.completions.create(
model=model,
max_tokens_to_sample=max_tokens,
@@ -34,7 +45,14 @@ class AnthropicLLM(BaseLLM):
return completion.completion
def _raw_gen_stream(
self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs
self,
baseself,
model,
messages,
stream=True,
tools=None,
max_tokens=300,
**kwargs,
):
context = messages[0]["content"]
user_question = messages[-1]["content"]
@@ -50,5 +68,5 @@ class AnthropicLLM(BaseLLM):
for completion in stream_response:
yield completion.completion
finally:
if hasattr(stream_response, 'close'):
if hasattr(stream_response, "close"):
stream_response.close()

View File

@@ -13,30 +13,32 @@ class BaseLLM(ABC):
def __init__(
self,
decoded_token=None,
model_id=None,
base_url=None,
):
self.decoded_token = decoded_token
self.model_id = model_id
self.base_url = base_url
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
self.fallback_provider = settings.FALLBACK_LLM_PROVIDER
self.fallback_model_name = settings.FALLBACK_LLM_NAME
self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY
self._fallback_llm = None
self._fallback_sequence_index = 0
@property
def fallback_llm(self):
"""Lazy-loaded fallback LLM instance."""
if (
self._fallback_llm is None
and self.fallback_provider
and self.fallback_model_name
):
"""Lazy-loaded fallback LLM from FALLBACK_* settings."""
if self._fallback_llm is None and settings.FALLBACK_LLM_PROVIDER:
try:
from application.llm.llm_creator import LLMCreator
self._fallback_llm = LLMCreator.create_llm(
self.fallback_provider,
self.fallback_llm_api_key,
None,
self.decoded_token,
settings.FALLBACK_LLM_PROVIDER,
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
user_api_key=None,
decoded_token=self.decoded_token,
model_id=settings.FALLBACK_LLM_NAME,
)
logger.info(
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
)
except Exception as e:
logger.error(
@@ -54,7 +56,7 @@ class BaseLLM(ABC):
self, method_name: str, decorators: list, *args, **kwargs
):
"""
Unified method execution with fallback support.
Execute method with fallback support.
Args:
method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream')
@@ -73,10 +75,10 @@ class BaseLLM(ABC):
return decorated_method()
except Exception as e:
if not self.fallback_llm:
logger.error(f"Primary LLM failed and no fallback available: {str(e)}")
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
raise
logger.warning(
f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}"
f"Primary LLM failed. Falling back to {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}. Error: {str(e)}"
)
fallback_method = getattr(

View File

@@ -1,5 +1,7 @@
import json
from openai import OpenAI
from application.core.settings import settings
from application.llm.base import BaseLLM
@@ -7,12 +9,11 @@ from application.llm.base import BaseLLM
class DocsGPTAPILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from openai import OpenAI
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key="sk-docsgpt-public", base_url="https://oai.arc53.com")
self.api_key = "sk-docsgpt-public"
self.client = OpenAI(api_key=self.api_key, base_url="https://oai.arc53.com")
self.user_api_key = user_api_key
self.api_key = api_key
def _clean_messages_openai(self, messages):
cleaned_messages = []
@@ -22,7 +23,6 @@ class DocsGPTAPILLM(BaseLLM):
if role == "model":
role = "assistant"
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
@@ -69,7 +69,6 @@ class DocsGPTAPILLM(BaseLLM):
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
def _raw_gen(
@@ -121,7 +120,6 @@ class DocsGPTAPILLM(BaseLLM):
response = self.client.chat.completions.create(
model="docsgpt", messages=messages, stream=stream, **kwargs
)
try:
for line in response:
if (
@@ -133,8 +131,8 @@ class DocsGPTAPILLM(BaseLLM):
elif len(line.choices) > 0:
yield line.choices[0]
finally:
if hasattr(response, 'close'):
if hasattr(response, "close"):
response.close()
def _supports_tools(self):
return True
return True

View File

@@ -13,8 +13,9 @@ from application.storage.storage_creator import StorageCreator
class GoogleLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = api_key
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
self.client = genai.Client(api_key=self.api_key)
self.storage = StorageCreator.get_storage()
@@ -47,21 +48,19 @@ class GoogleLLM(BaseLLM):
"""
if not attachments:
return messages
prepared_messages = messages.copy()
# Find the user message to attach files to the last one
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
@@ -69,7 +68,6 @@ class GoogleLLM(BaseLLM):
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
files = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
@@ -92,11 +90,9 @@ class GoogleLLM(BaseLLM):
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
}
)
if files:
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
prepared_messages[user_message_index]["content"].append({"files": files})
return prepared_messages
def _upload_file_to_google(self, attachment):
@@ -111,14 +107,11 @@ class GoogleLLM(BaseLLM):
"""
if "google_file_uri" in attachment:
return attachment["google_file_uri"]
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
try:
file_uri = self.storage.process_file(
file_path,
@@ -136,7 +129,6 @@ class GoogleLLM(BaseLLM):
attachments_collection.update_one(
{"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
)
return file_uri
except Exception as e:
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
@@ -153,7 +145,6 @@ class GoogleLLM(BaseLLM):
role = "model"
elif role == "tool":
role = "model"
parts = []
if role and content is not None:
if isinstance(content, str):
@@ -164,6 +155,7 @@ class GoogleLLM(BaseLLM):
parts.append(types.Part.from_text(text=item["text"]))
elif "function_call" in item:
# Remove null values from args to avoid API errors
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
@@ -194,10 +186,8 @@ class GoogleLLM(BaseLLM):
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
return cleaned_messages
def _clean_schema(self, schema_obj):
@@ -233,8 +223,8 @@ class GoogleLLM(BaseLLM):
cleaned[key] = [self._clean_schema(item) for item in value]
else:
cleaned[key] = value
# Validate that required properties actually exist in properties
if "required" in cleaned and "properties" in cleaned:
valid_required = []
properties_keys = set(cleaned["properties"].keys())
@@ -247,7 +237,6 @@ class GoogleLLM(BaseLLM):
cleaned.pop("required", None)
elif "required" in cleaned and "properties" not in cleaned:
cleaned.pop("required", None)
return cleaned
def _clean_tools_format(self, tools_list):
@@ -263,7 +252,6 @@ class GoogleLLM(BaseLLM):
cleaned_properties = {}
for k, v in properties.items():
cleaned_properties[k] = self._clean_schema(v)
genai_function = dict(
name=function["name"],
description=function["description"],
@@ -282,10 +270,8 @@ class GoogleLLM(BaseLLM):
name=function["name"],
description=function["description"],
)
genai_tool = types.Tool(function_declarations=[genai_function])
genai_tools.append(genai_tool)
return genai_tools
def _raw_gen(
@@ -307,16 +293,14 @@ class GoogleLLM(BaseLLM):
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
config.response_mime_type = "application/json"
response = client.models.generate_content(
model=model,
contents=messages,
@@ -347,17 +331,16 @@ class GoogleLLM(BaseLLM):
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
config.response_mime_type = "application/json"
# Check if we have both tools and file attachments
has_attachments = False
for message in messages:
for part in message.parts:
@@ -366,7 +349,6 @@ class GoogleLLM(BaseLLM):
break
if has_attachments:
break
logging.info(
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
)
@@ -405,7 +387,6 @@ class GoogleLLM(BaseLLM):
"""Convert JSON schema to Google AI structured output format."""
if not json_schema:
return None
type_map = {
"object": "OBJECT",
"array": "ARRAY",
@@ -418,12 +399,10 @@ class GoogleLLM(BaseLLM):
def convert(schema):
if not isinstance(schema, dict):
return schema
result = {}
schema_type = schema.get("type")
if schema_type:
result["type"] = type_map.get(schema_type.lower(), schema_type.upper())
for key in [
"description",
"nullable",
@@ -435,7 +414,6 @@ class GoogleLLM(BaseLLM):
]:
if key in schema:
result[key] = schema[key]
if "format" in schema:
format_value = schema["format"]
if schema_type == "string":
@@ -445,21 +423,17 @@ class GoogleLLM(BaseLLM):
result["format"] = format_value
else:
result["format"] = format_value
if "properties" in schema:
result["properties"] = {
k: convert(v) for k, v in schema["properties"].items()
}
if "propertyOrdering" not in result and result.get("type") == "OBJECT":
result["propertyOrdering"] = list(result["properties"].keys())
if "items" in schema:
result["items"] = convert(schema["items"])
for field in ["anyOf", "oneOf", "allOf"]:
if field in schema:
result[field] = [convert(s) for s in schema[field]]
return result
try:

View File

@@ -1,13 +1,18 @@
from application.llm.base import BaseLLM
from openai import OpenAI
from application.core.settings import settings
from application.llm.base import BaseLLM
class GroqLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
self.api_key = api_key
self.api_key = api_key or settings.GROQ_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
self.client = OpenAI(
api_key=self.api_key, base_url="https://api.groq.com/openai/v1"
)
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
if tools:

View File

@@ -282,7 +282,7 @@ class LLMHandler(ABC):
messages = e.value
break
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
model=agent.model_id, messages=messages, tools=agent.tools
)
parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm))
@@ -337,7 +337,7 @@ class LLMHandler(ABC):
tool_calls = {}
response = agent.llm.gen_stream(
model=agent.gpt_model, messages=messages, tools=agent.tools
model=agent.model_id, messages=messages, tools=agent.tools
)
self.llm_calls.append(build_stack_data(agent.llm))

View File

@@ -1,13 +1,17 @@
from application.llm.groq import GroqLLM
from application.llm.openai import OpenAILLM, AzureOpenAILLM
from application.llm.sagemaker import SagemakerAPILLM
from application.llm.huggingface import HuggingFaceLLM
from application.llm.llama_cpp import LlamaCpp
import logging
from application.llm.anthropic import AnthropicLLM
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.premai import PremAILLM
from application.llm.google_ai import GoogleLLM
from application.llm.groq import GroqLLM
from application.llm.huggingface import HuggingFaceLLM
from application.llm.llama_cpp import LlamaCpp
from application.llm.novita import NovitaLLM
from application.llm.openai import AzureOpenAILLM, OpenAILLM
from application.llm.premai import PremAILLM
from application.llm.sagemaker import SagemakerAPILLM
logger = logging.getLogger(__name__)
class LLMCreator:
@@ -26,10 +30,26 @@ class LLMCreator:
}
@classmethod
def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
def create_llm(
cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
):
from application.core.model_utils import get_base_url_for_model
llm_class = cls.llms.get(type.lower())
if not llm_class:
raise ValueError(f"No LLM class found for type {type}")
# Extract base_url from model configuration if model_id is provided
base_url = None
if model_id:
base_url = get_base_url_for_model(model_id)
return llm_class(
api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
api_key,
user_api_key,
decoded_token=decoded_token,
model_id=model_id,
base_url=base_url,
*args,
**kwargs,
)

View File

@@ -2,6 +2,8 @@ import base64
import json
import logging
from openai import OpenAI
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
@@ -9,20 +11,25 @@ from application.storage.storage_creator import StorageCreator
class OpenAILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from openai import OpenAI
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(*args, **kwargs)
if (
self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
# Priority: 1) Parameter base_url, 2) Settings OPENAI_BASE_URL, 3) Default
effective_base_url = None
if base_url and isinstance(base_url, str) and base_url.strip():
effective_base_url = base_url
elif (
isinstance(settings.OPENAI_BASE_URL, str)
and settings.OPENAI_BASE_URL.strip()
):
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
effective_base_url = settings.OPENAI_BASE_URL
else:
DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
self.client = OpenAI(api_key=api_key, base_url=DEFAULT_OPENAI_API_BASE)
self.api_key = api_key
self.user_api_key = user_api_key
effective_base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
self.storage = StorageCreator.get_storage()
def _clean_messages_openai(self, messages):
@@ -33,7 +40,6 @@ class OpenAILLM(BaseLLM):
if role == "model":
role = "assistant"
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
@@ -107,7 +113,6 @@ class OpenAILLM(BaseLLM):
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
def _raw_gen(
@@ -132,10 +137,8 @@ class OpenAILLM(BaseLLM):
if tools:
request_params["tools"] = tools
if response_format:
request_params["response_format"] = response_format
response = self.client.chat.completions.create(**request_params)
if tools:
@@ -165,10 +168,8 @@ class OpenAILLM(BaseLLM):
if tools:
request_params["tools"] = tools
if response_format:
request_params["response_format"] = response_format
response = self.client.chat.completions.create(**request_params)
try:
@@ -194,7 +195,6 @@ class OpenAILLM(BaseLLM):
def prepare_structured_output_format(self, json_schema):
if not json_schema:
return None
try:
def add_additional_properties_false(schema_obj):
@@ -204,11 +204,11 @@ class OpenAILLM(BaseLLM):
if schema_copy.get("type") == "object":
schema_copy["additionalProperties"] = False
# Ensure 'required' includes all properties for OpenAI strict mode
if "properties" in schema_copy:
schema_copy["required"] = list(
schema_copy["properties"].keys()
)
for key, value in schema_copy.items():
if key == "properties" and isinstance(value, dict):
schema_copy[key] = {
@@ -224,7 +224,6 @@ class OpenAILLM(BaseLLM):
add_additional_properties_false(sub_schema)
for sub_schema in value
]
return schema_copy
return schema_obj
@@ -243,7 +242,6 @@ class OpenAILLM(BaseLLM):
}
return result
except Exception as e:
logging.error(f"Error preparing structured output format: {e}")
return None
@@ -277,21 +275,19 @@ class OpenAILLM(BaseLLM):
"""
if not attachments:
return messages
prepared_messages = messages.copy()
# Find the user message to attach file_id to the last one
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
@@ -299,7 +295,6 @@ class OpenAILLM(BaseLLM):
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
@@ -326,6 +321,7 @@ class OpenAILLM(BaseLLM):
}
)
# Handle PDFs using the file API
elif mime_type == "application/pdf":
try:
file_id = self._upload_file_to_openai(attachment)
@@ -341,7 +337,6 @@ class OpenAILLM(BaseLLM):
"text": f"File content:\n\n{attachment['content']}",
}
)
return prepared_messages
def _get_base64_image(self, attachment):
@@ -357,7 +352,6 @@ class OpenAILLM(BaseLLM):
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
try:
with self.storage.get_file(file_path) as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
@@ -381,12 +375,10 @@ class OpenAILLM(BaseLLM):
if "openai_file_id" in attachment:
return attachment["openai_file_id"]
file_path = attachment.get("path")
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
try:
file_id = self.storage.process_file(
file_path,
@@ -404,7 +396,6 @@ class OpenAILLM(BaseLLM):
attachments_collection.update_one(
{"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
)
return file_id
except Exception as e:
logging.error(f"Error uploading file to OpenAI: {e}", exc_info=True)

View File

@@ -16,7 +16,7 @@ class ClassicRAG(BaseRetriever):
prompt="",
chunks=2,
doc_token_limit=50000,
gpt_model="docsgpt",
model_id="docsgpt-local",
user_api_key=None,
llm_name=settings.LLM_PROVIDER,
api_key=settings.API_KEY,
@@ -40,7 +40,7 @@ class ClassicRAG(BaseRetriever):
f"ClassicRAG initialized with chunks={self.chunks}, user_api_key={user_identifier}, "
f"sources={'active_docs' in source and source['active_docs'] is not None}"
)
self.gpt_model = gpt_model
self.model_id = model_id
self.doc_token_limit = doc_token_limit
self.user_api_key = user_api_key
self.llm_name = llm_name
@@ -100,7 +100,7 @@ class ClassicRAG(BaseRetriever):
]
try:
rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages)
rephrased_query = self.llm.gen(model=self.model_id, messages=messages)
print(f"Rephrased query: {rephrased_query}")
return rephrased_query if rephrased_query else self.original_question
except Exception as e:

View File

@@ -7,6 +7,8 @@ import tiktoken
from flask import jsonify, make_response
from werkzeug.utils import secure_filename
from application.core.model_utils import get_token_limit
from application.core.settings import settings
@@ -75,11 +77,9 @@ def count_tokens_docs(docs):
def calculate_doc_token_budget(
gpt_model: str = "gpt-4o", history_token_limit: int = 2000
model_id: str = "gpt-4o", history_token_limit: int = 2000
) -> int:
total_context = settings.LLM_TOKEN_LIMITS.get(
gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT
)
total_context = get_token_limit(model_id)
reserved = sum(settings.RESERVED_TOKENS.values())
doc_budget = total_context - history_token_limit - reserved
return max(doc_budget, 1000)
@@ -144,16 +144,13 @@ def get_hash(data):
return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
def limit_chat_history(history, max_token_limit=None, model_id="docsgpt-local"):
"""Limit chat history to fit within token limit."""
from application.core.settings import settings
model_token_limit = get_token_limit(model_id)
max_token_limit = (
max_token_limit
if max_token_limit
and max_token_limit
< settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT)
else settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT)
if max_token_limit and max_token_limit < model_token_limit
else model_token_limit
)
if not history:
@@ -205,37 +202,44 @@ def clean_text_for_tts(text: str) -> str:
clean text for Text-to-Speech processing.
"""
# Handle code blocks and links
text = re.sub(r'```mermaid[\s\S]*?```', ' flowchart, ', text) ## ```mermaid...```
text = re.sub(r'```[\s\S]*?```', ' code block, ', text) ## ```code```
text = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', text) ## [text](url)
text = re.sub(r'!\[([^\]]*)\]\([^\)]+\)', '', text) ## ![alt](url)
text = re.sub(r"```mermaid[\s\S]*?```", " flowchart, ", text) ## ```mermaid...```
text = re.sub(r"```[\s\S]*?```", " code block, ", text) ## ```code```
text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text) ## [text](url)
text = re.sub(r"!\[([^\]]*)\]\([^\)]+\)", "", text) ## ![alt](url)
# Remove markdown formatting
text = re.sub(r'`([^`]+)`', r'\1', text) ## `code`
text = re.sub(r'\{([^}]*)\}', r' \1 ', text) ## {text}
text = re.sub(r'[{}]', ' ', text) ## unmatched {}
text = re.sub(r'\[([^\]]+)\]', r' \1 ', text) ## [text]
text = re.sub(r'[\[\]]', ' ', text) ## unmatched []
text = re.sub(r'(\*\*|__)(.*?)\1', r'\2', text) ## **bold** __bold__
text = re.sub(r'(\*|_)(.*?)\1', r'\2', text) ## *italic* _italic_
text = re.sub(r'^#{1,6}\s+', '', text, flags=re.MULTILINE) ## # headers
text = re.sub(r'^>\s+', '', text, flags=re.MULTILINE) ## > blockquotes
text = re.sub(r'^[\s]*[-\*\+]\s+', '', text, flags=re.MULTILINE) ## - * + lists
text = re.sub(r'^[\s]*\d+\.\s+', '', text, flags=re.MULTILINE) ## 1. numbered lists
text = re.sub(r'^[\*\-_]{3,}\s*$', '', text, flags=re.MULTILINE) ## --- *** ___ rules
text = re.sub(r'<[^>]*>', '', text) ## <html> tags
#Remove non-ASCII (emojis, special Unicode)
text = re.sub(r'[^\x20-\x7E\n\r\t]', '', text)
text = re.sub(r"`([^`]+)`", r"\1", text) ## `code`
text = re.sub(r"\{([^}]*)\}", r" \1 ", text) ## {text}
text = re.sub(r"[{}]", " ", text) ## unmatched {}
text = re.sub(r"\[([^\]]+)\]", r" \1 ", text) ## [text]
text = re.sub(r"[\[\]]", " ", text) ## unmatched []
text = re.sub(r"(\*\*|__)(.*?)\1", r"\2", text) ## **bold** __bold__
text = re.sub(r"(\*|_)(.*?)\1", r"\2", text) ## *italic* _italic_
text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE) ## # headers
text = re.sub(r"^>\s+", "", text, flags=re.MULTILINE) ## > blockquotes
text = re.sub(r"^[\s]*[-\*\+]\s+", "", text, flags=re.MULTILINE) ## - * + lists
text = re.sub(r"^[\s]*\d+\.\s+", "", text, flags=re.MULTILINE) ## 1. numbered lists
text = re.sub(
r"^[\*\-_]{3,}\s*$", "", text, flags=re.MULTILINE
) ## --- *** ___ rules
text = re.sub(r"<[^>]*>", "", text) ## <html> tags
#Replace special sequences
text = re.sub(r'-->', ', ', text) ## -->
text = re.sub(r'<--', ', ', text) ## <--
text = re.sub(r'=>', ', ', text) ## =>
text = re.sub(r'::', ' ', text) ## ::
# Remove non-ASCII (emojis, special Unicode)
#Normalize whitespace
text = re.sub(r'\s+', ' ', text)
text = re.sub(r"[^\x20-\x7E\n\r\t]", "", text)
# Replace special sequences
text = re.sub(r"-->", ", ", text) ## -->
text = re.sub(r"<--", ", ", text) ## <--
text = re.sub(r"=>", ", ", text) ## =>
text = re.sub(r"::", " ", text) ## ::
# Normalize whitespace
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text

View File

@@ -165,7 +165,7 @@ def run_agent_logic(agent_config, input_data):
agent_type,
endpoint="webhook",
llm_name=settings.LLM_PROVIDER,
gpt_model=settings.LLM_NAME,
model_id=settings.LLM_NAME,
api_key=settings.API_KEY,
user_api_key=user_api_key,
prompt=prompt,
@@ -180,7 +180,7 @@ def run_agent_logic(agent_config, input_data):
prompt=prompt,
chunks=chunks,
token_limit=settings.DEFAULT_MAX_HISTORY,
gpt_model=settings.LLM_NAME,
model_id=settings.LLM_NAME,
user_api_key=user_api_key,
decoded_token=decoded_token,
)