diff --git a/application/agents/agent_creator.py b/application/agents/agent_creator.py
index bf37d4ec..44e89552 100644
--- a/application/agents/agent_creator.py
+++ b/application/agents/agent_creator.py
@@ -1,5 +1,8 @@
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
+import logging
+
+logger = logging.getLogger(__name__)
class AgentCreator:
@@ -13,4 +16,5 @@ class AgentCreator:
agent_class = cls.agents.get(type.lower())
if not agent_class:
raise ValueError(f"No agent class found for type {type}")
+
return agent_class(*args, **kwargs)
diff --git a/application/agents/base.py b/application/agents/base.py
index 27428fc3..dbf15a1f 100644
--- a/application/agents/base.py
+++ b/application/agents/base.py
@@ -21,7 +21,7 @@ class BaseAgent(ABC):
self,
endpoint: str,
llm_name: str,
- gpt_model: str,
+ model_id: str,
api_key: str,
user_api_key: Optional[str] = None,
prompt: str = "",
@@ -37,7 +37,7 @@ class BaseAgent(ABC):
):
self.endpoint = endpoint
self.llm_name = llm_name
- self.gpt_model = gpt_model
+ self.model_id = model_id
self.api_key = api_key
self.user_api_key = user_api_key
self.prompt = prompt
@@ -52,6 +52,7 @@ class BaseAgent(ABC):
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
+ model_id=model_id,
)
self.retrieved_docs = retrieved_docs or []
self.llm_handler = LLMHandlerCreator.create_handler(
@@ -316,7 +317,7 @@ class BaseAgent(ABC):
return messages
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
- gen_kwargs = {"model": self.gpt_model, "messages": messages}
+ gen_kwargs = {"model": self.model_id, "messages": messages}
if (
hasattr(self.llm, "_supports_tools")
diff --git a/application/agents/react_agent.py b/application/agents/react_agent.py
index 49dd29d8..116fa4aa 100644
--- a/application/agents/react_agent.py
+++ b/application/agents/react_agent.py
@@ -86,7 +86,7 @@ class ReActAgent(BaseAgent):
messages = [{"role": "user", "content": plan_prompt}]
plan_stream = self.llm.gen_stream(
- model=self.gpt_model,
+ model=self.model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
@@ -151,7 +151,7 @@ class ReActAgent(BaseAgent):
messages = [{"role": "user", "content": final_prompt}]
final_stream = self.llm.gen_stream(
- model=self.gpt_model, messages=messages, tools=None
+ model=self.model_id, messages=messages, tools=None
)
if log_context:
diff --git a/application/api/answer/routes/answer.py b/application/api/answer/routes/answer.py
index 87d80059..bc7ec58c 100644
--- a/application/api/answer/routes/answer.py
+++ b/application/api/answer/routes/answer.py
@@ -54,6 +54,10 @@ class AnswerResource(Resource, BaseAnswerResource):
default=True,
description="Whether to save the conversation",
),
+ "model_id": fields.String(
+ required=False,
+ description="Model ID to use for this request",
+ ),
"passthrough": fields.Raw(
required=False,
description="Dynamic parameters to inject into prompt template",
@@ -97,6 +101,7 @@ class AnswerResource(Resource, BaseAnswerResource):
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
+ model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)
diff --git a/application/api/answer/routes/base.py b/application/api/answer/routes/base.py
index 43e83ed2..aefb66c6 100644
--- a/application/api/answer/routes/base.py
+++ b/application/api/answer/routes/base.py
@@ -7,11 +7,16 @@ from flask import jsonify, make_response, Response
from flask_restx import Namespace
from application.api.answer.services.conversation_service import ConversationService
+from application.core.model_utils import (
+ get_api_key_for_provider,
+ get_default_model_id,
+ get_provider_from_model_id,
+)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
-from application.utils import check_required_fields, get_gpt_model
+from application.utils import check_required_fields
logger = logging.getLogger(__name__)
@@ -27,7 +32,7 @@ class BaseAnswerResource:
db = mongo[settings.MONGO_DB_NAME]
self.db = db
self.user_logs_collection = db["user_logs"]
- self.gpt_model = get_gpt_model()
+ self.default_model_id = get_default_model_id()
self.conversation_service = ConversationService()
def validate_request(
@@ -54,7 +59,6 @@ class BaseAnswerResource:
api_key = agent_config.get("user_api_key")
if not api_key:
return None
-
agents_collection = self.db["agents"]
agent = agents_collection.find_one({"key": api_key})
@@ -62,7 +66,6 @@ class BaseAnswerResource:
return make_response(
jsonify({"success": False, "message": "Invalid API key."}), 401
)
-
limited_token_mode_raw = agent.get("limited_token_mode", False)
limited_request_mode_raw = agent.get("limited_request_mode", False)
@@ -110,15 +113,12 @@ class BaseAnswerResource:
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
else:
daily_token_usage = 0
-
if limited_request_mode:
daily_request_usage = token_usage_collection.count_documents(match_query)
else:
daily_request_usage = 0
-
if not limited_token_mode and not limited_request_mode:
return None
-
token_exceeded = (
limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
)
@@ -138,7 +138,6 @@ class BaseAnswerResource:
),
429,
)
-
return None
def complete_stream(
@@ -155,6 +154,7 @@ class BaseAnswerResource:
agent_id: Optional[str] = None,
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
+ model_id: Optional[str] = None,
) -> Generator[str, None, None]:
"""
Generator function that streams the complete conversation response.
@@ -173,6 +173,7 @@ class BaseAnswerResource:
agent_id: ID of agent used
is_shared_usage: Flag for shared agent usage
shared_token: Token for shared agent
+ model_id: Model ID used for the request
retrieved_docs: Pre-fetched documents for sources (optional)
Yields:
@@ -220,7 +221,6 @@ class BaseAnswerResource:
elif "type" in line:
data = json.dumps(line)
yield f"data: {data}\n\n"
-
if is_structured and structured_chunks:
structured_data = {
"type": "structured_answer",
@@ -230,15 +230,22 @@ class BaseAnswerResource:
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
-
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
+ provider = (
+ get_provider_from_model_id(model_id)
+ if model_id
+ else settings.LLM_PROVIDER
+ )
+ system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
+
llm = LLMCreator.create_llm(
- settings.LLM_PROVIDER,
- api_key=settings.API_KEY,
+ provider or settings.LLM_PROVIDER,
+ api_key=system_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
+ model_id=model_id,
)
if should_save_conversation:
@@ -250,7 +257,7 @@ class BaseAnswerResource:
source_log_docs,
tool_calls,
llm,
- self.gpt_model,
+ model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
@@ -280,12 +287,11 @@ class BaseAnswerResource:
log_data["structured_output"] = True
if schema_info:
log_data["schema"] = schema_info
-
# Clean up text fields to be no longer than 10000 characters
+
for key, value in log_data.items():
if isinstance(value, str) and len(value) > 10000:
log_data[key] = value[:10000]
-
self.user_logs_collection.insert_one(log_data)
data = json.dumps({"type": "end"})
@@ -293,6 +299,7 @@ class BaseAnswerResource:
except GeneratorExit:
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
# Save partial response
+
if should_save_conversation and response_full:
try:
if isNoneDoc:
@@ -312,7 +319,7 @@ class BaseAnswerResource:
source_log_docs,
tool_calls,
llm,
- self.gpt_model,
+ model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
@@ -369,7 +376,7 @@ class BaseAnswerResource:
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
- return None, None, None, None, event["error"]
+ return None, None, None, None, event["error"], None
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
@@ -377,8 +384,7 @@ class BaseAnswerResource:
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
- return None, None, None, None, "Stream ended unexpectedly"
-
+ return None, None, None, None, "Stream ended unexpectedly", None
result = (
conversation_id,
response_full,
@@ -390,7 +396,6 @@ class BaseAnswerResource:
if is_structured:
result = result + ({"structured": True, "schema": schema_info},)
-
return result
def error_stream_generate(self, err_response):
diff --git a/application/api/answer/routes/stream.py b/application/api/answer/routes/stream.py
index 92e41c14..b2827a93 100644
--- a/application/api/answer/routes/stream.py
+++ b/application/api/answer/routes/stream.py
@@ -57,6 +57,10 @@ class StreamResource(Resource, BaseAnswerResource):
default=True,
description="Whether to save the conversation",
),
+ "model_id": fields.String(
+ required=False,
+ description="Model ID to use for this request",
+ ),
"attachments": fields.List(
fields.String, required=False, description="List of attachment IDs"
),
@@ -101,6 +105,7 @@ class StreamResource(Resource, BaseAnswerResource):
agent_id=data.get("agent_id"),
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
+ model_id=processor.model_id,
),
mimetype="text/event-stream",
)
diff --git a/application/api/answer/services/conversation_service.py b/application/api/answer/services/conversation_service.py
index eca842d6..0e98983e 100644
--- a/application/api/answer/services/conversation_service.py
+++ b/application/api/answer/services/conversation_service.py
@@ -52,7 +52,7 @@ class ConversationService:
sources: List[Dict[str, Any]],
tool_calls: List[Dict[str, Any]],
llm: Any,
- gpt_model: str,
+ model_id: str,
decoded_token: Dict[str, Any],
index: Optional[int] = None,
api_key: Optional[str] = None,
@@ -66,7 +66,7 @@ class ConversationService:
if not user_id:
raise ValueError("User ID not found in token")
current_time = datetime.now(timezone.utc)
-
+
# clean up in sources array such that we save max 1k characters for text part
for source in sources:
if "text" in source and isinstance(source["text"], str):
@@ -90,6 +90,7 @@ class ConversationService:
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
+ f"queries.{index}.model_id": model_id,
}
},
)
@@ -120,6 +121,7 @@ class ConversationService:
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
+ "model_id": model_id,
}
}
},
@@ -146,7 +148,7 @@ class ConversationService:
]
completion = llm.gen(
- model=gpt_model, messages=messages_summary, max_tokens=30
+ model=model_id, messages=messages_summary, max_tokens=30
)
conversation_data = {
@@ -162,6 +164,7 @@ class ConversationService:
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
+ "model_id": model_id,
}
],
}
diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py
index bb890937..586e7696 100644
--- a/application/api/answer/services/stream_processor.py
+++ b/application/api/answer/services/stream_processor.py
@@ -12,12 +12,17 @@ from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.conversation_service import ConversationService
from application.api.answer.services.prompt_renderer import PromptRenderer
+from application.core.model_utils import (
+ get_api_key_for_provider,
+ get_default_model_id,
+ get_provider_from_model_id,
+ validate_model_id,
+)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import (
calculate_doc_token_budget,
- get_gpt_model,
limit_chat_history,
)
@@ -83,7 +88,7 @@ class StreamProcessor:
self.retriever_config = {}
self.is_shared_usage = False
self.shared_token = None
- self.gpt_model = get_gpt_model()
+ self.model_id: Optional[str] = None
self.conversation_service = ConversationService()
self.prompt_renderer = PromptRenderer()
self._prompt_content: Optional[str] = None
@@ -91,6 +96,7 @@ class StreamProcessor:
def initialize(self):
"""Initialize all required components for processing"""
+ self._validate_and_set_model()
self._configure_agent()
self._configure_source()
self._configure_retriever()
@@ -112,7 +118,7 @@ class StreamProcessor:
]
else:
self.history = limit_chat_history(
- json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
+ json.loads(self.data.get("history", "[]")), model_id=self.model_id
)
def _process_attachments(self):
@@ -143,6 +149,25 @@ class StreamProcessor:
)
return attachments
+    def _validate_and_set_model(self):
+        """Resolve ``self.model_id`` from the request payload.
+
+        A ``model_id`` found in ``self.data`` is accepted only when
+        ``validate_model_id`` approves it; a missing or empty value selects
+        the system default instead.
+
+        Raises:
+            ValueError: If the requested id is rejected. The message lists
+                up to five enabled model ids and counts the remainder.
+        """
+        from application.core.model_settings import ModelRegistry
+
+        wanted = self.data.get("model_id")
+
+        # Nothing requested: use the configured default.
+        if not wanted:
+            self.model_id = get_default_model_id()
+            return
+
+        if validate_model_id(wanted):
+            self.model_id = wanted
+            return
+
+        # Rejected: build a short hint of what the caller could have asked for.
+        enabled_ids = [
+            entry.id for entry in ModelRegistry.get_instance().get_enabled_models()
+        ]
+        preview = ", ".join(enabled_ids[:5])
+        overflow = len(enabled_ids) - 5
+        tail = f" and {overflow} more" if overflow > 0 else ""
+        raise ValueError(
+            f"Invalid model_id '{wanted}'. " f"Available models: {preview}" + tail
+        )
+
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control"""
if not agent_id:
@@ -322,7 +347,7 @@ class StreamProcessor:
def _configure_retriever(self):
history_token_limit = int(self.data.get("token_limit", 2000))
doc_token_limit = calculate_doc_token_budget(
- gpt_model=self.gpt_model, history_token_limit=history_token_limit
+ model_id=self.model_id, history_token_limit=history_token_limit
)
self.retriever_config = {
@@ -344,7 +369,7 @@ class StreamProcessor:
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
chunks=self.retriever_config["chunks"],
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
- gpt_model=self.gpt_model,
+ model_id=self.model_id,
user_api_key=self.agent_config["user_api_key"],
decoded_token=self.decoded_token,
)
@@ -626,12 +651,19 @@ class StreamProcessor:
tools_data=tools_data,
)
+ provider = (
+ get_provider_from_model_id(self.model_id)
+ if self.model_id
+ else settings.LLM_PROVIDER
+ )
+ system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
+
return AgentCreator.create_agent(
self.agent_config["agent_type"],
endpoint="stream",
- llm_name=settings.LLM_PROVIDER,
- gpt_model=self.gpt_model,
- api_key=settings.API_KEY,
+ llm_name=provider or settings.LLM_PROVIDER,
+ model_id=self.model_id,
+ api_key=system_api_key,
user_api_key=self.agent_config["user_api_key"],
prompt=rendered_prompt,
chat_history=self.history,
diff --git a/application/api/user/agents/routes.py b/application/api/user/agents/routes.py
index 6c43342f..bcf68b56 100644
--- a/application/api/user/agents/routes.py
+++ b/application/api/user/agents/routes.py
@@ -95,6 +95,8 @@ class GetAgent(Resource):
"shared": agent.get("shared_publicly", False),
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
+ "models": agent.get("models", []),
+ "default_model_id": agent.get("default_model_id", ""),
}
return make_response(jsonify(data), 200)
except Exception as e:
@@ -172,6 +174,8 @@ class GetAgents(Resource):
"shared": agent.get("shared_publicly", False),
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
+ "models": agent.get("models", []),
+ "default_model_id": agent.get("default_model_id", ""),
}
for agent in agents
if "source" in agent or "retriever" in agent
@@ -230,6 +234,14 @@ class CreateAgent(Resource):
required=False,
description="Request limit for the agent in limited mode",
),
+ "models": fields.List(
+ fields.String,
+ required=False,
+ description="List of available model IDs for this agent",
+ ),
+ "default_model_id": fields.String(
+ required=False, description="Default model ID for this agent"
+ ),
},
)
@@ -258,6 +270,11 @@ class CreateAgent(Resource):
data["json_schema"] = json.loads(data["json_schema"])
except json.JSONDecodeError:
data["json_schema"] = None
+ if "models" in data:
+ try:
+ data["models"] = json.loads(data["models"])
+ except json.JSONDecodeError:
+ data["models"] = []
print(f"Received data: {data}")
# Validate JSON schema if provided
@@ -399,6 +416,8 @@ class CreateAgent(Resource):
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
"lastUsedAt": None,
"key": key,
+ "models": data.get("models", []),
+ "default_model_id": data.get("default_model_id", ""),
}
if new_agent["chunks"] == "":
new_agent["chunks"] = "2"
@@ -464,6 +483,14 @@ class UpdateAgent(Resource):
required=False,
description="Request limit for the agent in limited mode",
),
+ "models": fields.List(
+ fields.String,
+ required=False,
+ description="List of available model IDs for this agent",
+ ),
+ "default_model_id": fields.String(
+ required=False, description="Default model ID for this agent"
+ ),
},
)
@@ -487,7 +514,7 @@ class UpdateAgent(Resource):
data = request.get_json()
else:
data = request.form.to_dict()
- json_fields = ["tools", "sources", "json_schema"]
+ json_fields = ["tools", "sources", "json_schema", "models"]
for field in json_fields:
if field in data and data[field]:
try:
@@ -555,6 +582,8 @@ class UpdateAgent(Resource):
"token_limit",
"limited_request_mode",
"request_limit",
+ "models",
+ "default_model_id",
]
for field in allowed_fields:
diff --git a/application/api/user/models/__init__.py b/application/api/user/models/__init__.py
new file mode 100644
index 00000000..f32afa11
--- /dev/null
+++ b/application/api/user/models/__init__.py
@@ -0,0 +1,3 @@
+from .routes import models_ns
+
+__all__ = ["models_ns"]
diff --git a/application/api/user/models/routes.py b/application/api/user/models/routes.py
new file mode 100644
index 00000000..886999b2
--- /dev/null
+++ b/application/api/user/models/routes.py
@@ -0,0 +1,25 @@
+from flask import current_app, jsonify, make_response
+from flask_restx import Namespace, Resource
+
+from application.core.model_settings import ModelRegistry
+
+models_ns = Namespace("models", description="Available models", path="/api")
+
+
+@models_ns.route("/models")
+class ModelsListResource(Resource):
+    """REST resource exposing the models enabled in the registry."""
+
+    def get(self):
+        """Return the enabled models, the default model id and a count.
+
+        Answers HTTP 200 with a JSON body on success. Any exception raised
+        while reading the registry is logged and answered with HTTP 500 and
+        ``{"success": false}``.
+        """
+        try:
+            registry = ModelRegistry.get_instance()
+            enabled = registry.get_enabled_models()
+            serialised = [entry.to_dict() for entry in enabled]
+            payload = {
+                "models": serialised,
+                "default_model_id": registry.default_model_id,
+                "count": len(enabled),
+            }
+        except Exception as err:
+            current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
+            return make_response(jsonify({"success": False}), 500)
+        else:
+            # Built outside the try body, so serialisation errors propagate.
+            return make_response(jsonify(payload), 200)
diff --git a/application/api/user/routes.py b/application/api/user/routes.py
index 1e0dbb4e..82e395c5 100644
--- a/application/api/user/routes.py
+++ b/application/api/user/routes.py
@@ -10,6 +10,7 @@ from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns
from .analytics import analytics_ns
from .attachments import attachments_ns
from .conversations import conversations_ns
+from .models import models_ns
from .prompts import prompts_ns
from .sharing import sharing_ns
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
@@ -27,6 +28,9 @@ api.add_namespace(attachments_ns)
# Conversations
api.add_namespace(conversations_ns)
+# Models
+api.add_namespace(models_ns)
+
# Agents (main, sharing, webhooks)
api.add_namespace(agents_ns)
api.add_namespace(agents_sharing_ns)
diff --git a/application/core/model_configs.py b/application/core/model_configs.py
new file mode 100644
index 00000000..b802ee27
--- /dev/null
+++ b/application/core/model_configs.py
@@ -0,0 +1,223 @@
+"""
+Model configurations for all supported LLM providers.
+"""
+
+from application.core.model_settings import (
+ AvailableModel,
+ ModelCapabilities,
+ ModelProvider,
+)
+
+# MIME types listed as supported attachments on the OpenAI and Azure OpenAI
+# model entries defined in this module.
+OPENAI_ATTACHMENTS = [
+    "application/pdf",
+    "image/png",
+    "image/jpeg",
+    "image/jpg",
+    "image/webp",
+    "image/gif",
+]
+
+# MIME types listed as supported attachments on the Google model entries.
+# Declared as its own list even though it currently matches
+# OPENAI_ATTACHMENTS entry for entry.
+GOOGLE_ATTACHMENTS = [
+    "application/pdf",
+    "image/png",
+    "image/jpeg",
+    "image/jpg",
+    "image/webp",
+    "image/gif",
+]
+
+
+# Catalogue of OpenAI models; ModelRegistry._add_openai_models registers
+# entries from this list keyed by their ``id``.
+# NOTE(review): the context_window values are hard-coded here (e.g. 4096 for
+# gpt-3.5-turbo) - confirm them against the provider's current limits.
+OPENAI_MODELS = [
+    AvailableModel(
+        id="gpt-4o",
+        provider=ModelProvider.OPENAI,
+        display_name="GPT-4 Omni",
+        description="Latest and most capable model",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            supports_structured_output=True,
+            supported_attachment_types=OPENAI_ATTACHMENTS,
+            context_window=128000,
+        ),
+    ),
+    AvailableModel(
+        id="gpt-4o-mini",
+        provider=ModelProvider.OPENAI,
+        display_name="GPT-4 Omni Mini",
+        description="Fast and efficient",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            supports_structured_output=True,
+            supported_attachment_types=OPENAI_ATTACHMENTS,
+            context_window=128000,
+        ),
+    ),
+    AvailableModel(
+        id="gpt-4-turbo",
+        provider=ModelProvider.OPENAI,
+        display_name="GPT-4 Turbo",
+        description="Fast GPT-4 with 128k context",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            supports_structured_output=True,
+            supported_attachment_types=OPENAI_ATTACHMENTS,
+            context_window=128000,
+        ),
+    ),
+    AvailableModel(
+        id="gpt-4",
+        provider=ModelProvider.OPENAI,
+        display_name="GPT-4",
+        description="Most capable model",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            supports_structured_output=True,
+            supported_attachment_types=OPENAI_ATTACHMENTS,
+            context_window=8192,
+        ),
+    ),
+    AvailableModel(
+        id="gpt-3.5-turbo",
+        provider=ModelProvider.OPENAI,
+        display_name="GPT-3.5 Turbo",
+        description="Fast and cost-effective",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            context_window=4096,
+        ),
+    ),
+]
+
+
+# Catalogue of Anthropic models; ModelRegistry._add_anthropic_models
+# registers entries from this list keyed by their ``id``.
+# None of these entries declare attachment types or structured output, so
+# the ModelCapabilities defaults (empty list / False) apply.
+ANTHROPIC_MODELS = [
+    AvailableModel(
+        id="claude-3-5-sonnet-20241022",
+        provider=ModelProvider.ANTHROPIC,
+        display_name="Claude 3.5 Sonnet (Latest)",
+        description="Latest Claude 3.5 Sonnet with enhanced capabilities",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            context_window=200000,
+        ),
+    ),
+    AvailableModel(
+        id="claude-3-5-sonnet",
+        provider=ModelProvider.ANTHROPIC,
+        display_name="Claude 3.5 Sonnet",
+        description="Balanced performance and capability",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            context_window=200000,
+        ),
+    ),
+    AvailableModel(
+        id="claude-3-opus",
+        provider=ModelProvider.ANTHROPIC,
+        display_name="Claude 3 Opus",
+        description="Most capable Claude model",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            context_window=200000,
+        ),
+    ),
+    AvailableModel(
+        id="claude-3-haiku",
+        provider=ModelProvider.ANTHROPIC,
+        display_name="Claude 3 Haiku",
+        description="Fastest Claude model",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            context_window=200000,
+        ),
+    ),
+]
+
+
+# Catalogue of Google models; ModelRegistry._add_google_models registers
+# entries from this list keyed by their ``id``.
+# NOTE(review): context_window values are hard-coded (1e6 and 2000000) -
+# confirm them against the provider's published limits.
+GOOGLE_MODELS = [
+    AvailableModel(
+        id="gemini-flash-latest",
+        provider=ModelProvider.GOOGLE,
+        display_name="Gemini Flash (Latest)",
+        description="Latest experimental Gemini model",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            supports_structured_output=True,
+            supported_attachment_types=GOOGLE_ATTACHMENTS,
+            context_window=int(1e6),
+        ),
+    ),
+    AvailableModel(
+        id="gemini-flash-lite-latest",
+        provider=ModelProvider.GOOGLE,
+        display_name="Gemini Flash Lite (Latest)",
+        description="Fast with huge context window",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            supports_structured_output=True,
+            supported_attachment_types=GOOGLE_ATTACHMENTS,
+            context_window=int(1e6),
+        ),
+    ),
+    AvailableModel(
+        id="gemini-2.5-pro",
+        provider=ModelProvider.GOOGLE,
+        display_name="Gemini 2.5 Pro",
+        description="Most capable Gemini model",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            supports_structured_output=True,
+            supported_attachment_types=GOOGLE_ATTACHMENTS,
+            context_window=2000000,
+        ),
+    ),
+]
+
+
+# Catalogue of Groq-hosted models; ModelRegistry._add_groq_models registers
+# entries from this list keyed by their ``id``.
+GROQ_MODELS = [
+    AvailableModel(
+        id="llama-3.3-70b-versatile",
+        provider=ModelProvider.GROQ,
+        display_name="Llama 3.3 70B",
+        description="Latest Llama model with high-speed inference",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            context_window=128000,
+        ),
+    ),
+    AvailableModel(
+        id="llama-3.1-8b-instant",
+        provider=ModelProvider.GROQ,
+        display_name="Llama 3.1 8B",
+        description="Ultra-fast inference",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            context_window=128000,
+        ),
+    ),
+    AvailableModel(
+        id="mixtral-8x7b-32768",
+        provider=ModelProvider.GROQ,
+        display_name="Mixtral 8x7B",
+        description="High-speed inference with tools",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            context_window=32768,
+        ),
+    ),
+]
+
+
+# Catalogue of Azure OpenAI models; ModelRegistry._add_azure_openai_models
+# registers entries from this list keyed by their ``id``. Reuses the OpenAI
+# attachment MIME types.
+AZURE_OPENAI_MODELS = [
+    AvailableModel(
+        id="azure-gpt-4",
+        provider=ModelProvider.AZURE_OPENAI,
+        display_name="Azure OpenAI GPT-4",
+        description="Azure-hosted GPT model",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            supports_structured_output=True,
+            supported_attachment_types=OPENAI_ATTACHMENTS,
+            context_window=8192,
+        ),
+    ),
+]
diff --git a/application/core/model_settings.py b/application/core/model_settings.py
new file mode 100644
index 00000000..87325ac3
--- /dev/null
+++ b/application/core/model_settings.py
@@ -0,0 +1,236 @@
+import logging
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Dict, List, Optional
+
+logger = logging.getLogger(__name__)
+
+
+class ModelProvider(str, Enum):
+    """Identifiers of the LLM providers a model entry can belong to.
+
+    Members also subclass ``str``; ``.value`` is the provider name that the
+    registry compares against ``settings.LLM_PROVIDER``.
+    """
+
+    OPENAI = "openai"
+    AZURE_OPENAI = "azure_openai"
+    ANTHROPIC = "anthropic"
+    GROQ = "groq"
+    GOOGLE = "google"
+    HUGGINGFACE = "huggingface"
+    LLAMA_CPP = "llama.cpp"
+    DOCSGPT = "docsgpt"
+    PREMAI = "premai"
+    SAGEMAKER = "sagemaker"
+    NOVITA = "novita"
+
+
+@dataclass
+class ModelCapabilities:
+    """Feature flags and limits recorded for a single model entry."""
+
+    supports_tools: bool = False
+    supports_structured_output: bool = False
+    supports_streaming: bool = True
+    # MIME type strings; a fresh empty list per instance when not supplied.
+    supported_attachment_types: List[str] = field(default_factory=list)
+    # Presumably measured in tokens - TODO confirm with the consumers.
+    context_window: int = 128000
+    # Optional pricing metadata; left as None when not supplied.
+    input_cost_per_token: Optional[float] = None
+    output_cost_per_token: Optional[float] = None
+
+
+@dataclass
+class AvailableModel:
+    """Registry entry describing one selectable model."""
+
+    id: str
+    provider: ModelProvider
+    display_name: str
+    description: str = ""
+    capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
+    enabled: bool = True
+    base_url: Optional[str] = None
+
+    def to_dict(self) -> Dict:
+        """Flatten the entry and its capabilities into a plain dict.
+
+        ``provider`` is emitted as its string value. ``base_url`` is included
+        only when it is set to a non-empty value.
+        """
+        caps = self.capabilities
+        payload = dict(
+            id=self.id,
+            provider=self.provider.value,
+            display_name=self.display_name,
+            description=self.description,
+            supported_attachment_types=caps.supported_attachment_types,
+            supports_tools=caps.supports_tools,
+            supports_structured_output=caps.supports_structured_output,
+            supports_streaming=caps.supports_streaming,
+            context_window=caps.context_window,
+            enabled=self.enabled,
+        )
+        if not self.base_url:
+            return payload
+        payload["base_url"] = self.base_url
+        return payload
+
+
+class ModelRegistry:
+    """Process-wide registry of the models this deployment can serve.
+
+    Every ``ModelRegistry()`` call (and ``get_instance()``) returns the same
+    object; the catalogue is loaded from settings only on the first
+    successful construction.
+    """
+
+    _instance = None
+    _initialized = False
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self):
+        # __init__ runs on every ModelRegistry() call; load only once.
+        if not ModelRegistry._initialized:
+            self.models: Dict[str, AvailableModel] = {}
+            self.default_model_id: Optional[str] = None
+            self._load_models()
+            ModelRegistry._initialized = True
+
+    @classmethod
+    def get_instance(cls) -> "ModelRegistry":
+        """Return the shared registry, creating it on first use."""
+        return cls()
+
+    def _load_models(self):
+        """Populate ``self.models`` from settings and pick the default model.
+
+        A provider's catalogue is registered when its dedicated API key is
+        set, or when it is the configured ``LLM_PROVIDER`` and the generic
+        ``API_KEY`` is set. The DocsGPT entry is always registered.
+        """
+        from application.core.settings import settings
+
+        self.models.clear()
+        self.default_model_id = None
+
+        def via_generic_key(provider_name: str) -> bool:
+            return bool(settings.LLM_PROVIDER == provider_name and settings.API_KEY)
+
+        self._add_docsgpt_models(settings)
+        if settings.OPENAI_API_KEY or via_generic_key("openai"):
+            self._add_openai_models(settings)
+        if settings.OPENAI_API_BASE or via_generic_key("azure_openai"):
+            self._add_azure_openai_models(settings)
+        if settings.ANTHROPIC_API_KEY or via_generic_key("anthropic"):
+            self._add_anthropic_models(settings)
+        if settings.GOOGLE_API_KEY or via_generic_key("google"):
+            self._add_google_models(settings)
+        if settings.GROQ_API_KEY or via_generic_key("groq"):
+            self._add_groq_models(settings)
+        if settings.HUGGINGFACE_API_KEY or via_generic_key("huggingface"):
+            self._add_huggingface_models(settings)
+
+        # Default model selection: explicit LLM_NAME first, then the first
+        # registered model of the configured provider.
+        if settings.LLM_NAME and settings.LLM_NAME in self.models:
+            self.default_model_id = settings.LLM_NAME
+        elif settings.LLM_PROVIDER and settings.API_KEY:
+            self.default_model_id = next(
+                (
+                    model_id
+                    for model_id, model in self.models.items()
+                    if model.provider.value == settings.LLM_PROVIDER
+                ),
+                None,
+            )
+        # The provider lookup can come up empty (no catalogue for that
+        # provider, or LLM_NAME filtered every entry out). Fall back to the
+        # first registered model rather than leaving the default unset.
+        if self.default_model_id is None and self.models:
+            self.default_model_id = next(iter(self.models))
+        logger.info(
+            f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
+        )
+
+    def _add_catalogue(self, catalogue, provider_name, settings, dedicated_key=None):
+        """Register entries of ``catalogue`` for ``provider_name``.
+
+        With a truthy ``dedicated_key`` the whole catalogue is registered.
+        Otherwise, when ``provider_name`` is the configured ``LLM_PROVIDER``
+        and ``LLM_NAME`` is set, only the entry whose id equals ``LLM_NAME``
+        is registered (possibly none). In every other case the whole
+        catalogue is registered.
+        """
+        pinned = bool(
+            not dedicated_key
+            and settings.LLM_PROVIDER == provider_name
+            and settings.LLM_NAME
+        )
+        for model in catalogue:
+            if not pinned or model.id == settings.LLM_NAME:
+                self.models[model.id] = model
+
+    # The catalogue imports below stay function-local because model_configs
+    # itself imports from this module.
+
+    def _add_openai_models(self, settings):
+        from application.core.model_configs import OPENAI_MODELS
+
+        self._add_catalogue(
+            OPENAI_MODELS, "openai", settings, settings.OPENAI_API_KEY
+        )
+
+    def _add_azure_openai_models(self, settings):
+        from application.core.model_configs import AZURE_OPENAI_MODELS
+
+        # Azure has no dedicated-key shortcut in this registry.
+        self._add_catalogue(AZURE_OPENAI_MODELS, "azure_openai", settings)
+
+    def _add_anthropic_models(self, settings):
+        from application.core.model_configs import ANTHROPIC_MODELS
+
+        self._add_catalogue(
+            ANTHROPIC_MODELS, "anthropic", settings, settings.ANTHROPIC_API_KEY
+        )
+
+    def _add_google_models(self, settings):
+        from application.core.model_configs import GOOGLE_MODELS
+
+        self._add_catalogue(
+            GOOGLE_MODELS, "google", settings, settings.GOOGLE_API_KEY
+        )
+
+    def _add_groq_models(self, settings):
+        from application.core.model_configs import GROQ_MODELS
+
+        self._add_catalogue(GROQ_MODELS, "groq", settings, settings.GROQ_API_KEY)
+
+    def _add_docsgpt_models(self, settings):
+        """Register the single built-in DocsGPT entry."""
+        model_id = "docsgpt-local"
+        model = AvailableModel(
+            id=model_id,
+            provider=ModelProvider.DOCSGPT,
+            display_name="DocsGPT Model",
+            description="Local model",
+            capabilities=ModelCapabilities(
+                supports_tools=False,
+                supported_attachment_types=[],
+            ),
+        )
+        self.models[model_id] = model
+
+    def _add_huggingface_models(self, settings):
+        """Register the single Hugging Face entry."""
+        model_id = "huggingface-local"
+        model = AvailableModel(
+            id=model_id,
+            provider=ModelProvider.HUGGINGFACE,
+            display_name="Hugging Face Model",
+            description="Local Hugging Face model",
+            capabilities=ModelCapabilities(
+                supports_tools=False,
+                supported_attachment_types=[],
+            ),
+        )
+        self.models[model_id] = model
+
+    def get_model(self, model_id: str) -> Optional[AvailableModel]:
+        """Return the entry registered under ``model_id``, or ``None``."""
+        return self.models.get(model_id)
+
+    def get_all_models(self) -> List[AvailableModel]:
+        """Return every registered entry, enabled or not."""
+        return list(self.models.values())
+
+    def get_enabled_models(self) -> List[AvailableModel]:
+        """Return the registered entries whose ``enabled`` flag is truthy."""
+        return [m for m in self.models.values() if m.enabled]
+
+    def model_exists(self, model_id: str) -> bool:
+        """Report whether ``model_id`` is registered (enabled or not)."""
+        return model_id in self.models
diff --git a/application/core/model_utils.py b/application/core/model_utils.py
new file mode 100644
index 00000000..f24dbf47
--- /dev/null
+++ b/application/core/model_utils.py
@@ -0,0 +1,91 @@
+from typing import Any, Dict, Optional
+
+from application.core.model_settings import ModelRegistry
+
+
+def get_api_key_for_provider(provider: str) -> Optional[str]:
+ """Get the API key for a provider, falling back to settings.API_KEY when no provider-specific key is set"""
+ from application.core.settings import settings
+
+ provider_key_map = {
+ "openai": settings.OPENAI_API_KEY,
+ "anthropic": settings.ANTHROPIC_API_KEY,
+ "google": settings.GOOGLE_API_KEY,
+ "groq": settings.GROQ_API_KEY,
+ "huggingface": settings.HUGGINGFACE_API_KEY,
+ "azure_openai": settings.API_KEY,
+ "docsgpt": None,
+ "llama.cpp": None,
+ }
+
+ provider_key = provider_key_map.get(provider)
+ if provider_key:
+ return provider_key
+ return settings.API_KEY
+
+
+def get_all_available_models() -> Dict[str, Dict[str, Any]]:
+ """Get all available models with metadata for API response"""
+ registry = ModelRegistry.get_instance()
+ return {model.id: model.to_dict() for model in registry.get_enabled_models()}
+
+
+def validate_model_id(model_id: str) -> bool:
+ """Check if a model ID exists in registry"""
+ registry = ModelRegistry.get_instance()
+ return registry.model_exists(model_id)
+
+
+def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
+ """Get capabilities for a specific model"""
+ registry = ModelRegistry.get_instance()
+ model = registry.get_model(model_id)
+ if model:
+ return {
+ "supported_attachment_types": model.capabilities.supported_attachment_types,
+ "supports_tools": model.capabilities.supports_tools,
+ "supports_structured_output": model.capabilities.supports_structured_output,
+ "context_window": model.capabilities.context_window,
+ }
+ return None
+
+
+def get_default_model_id() -> str:
+ """Get the system default model ID"""
+ registry = ModelRegistry.get_instance()
+ return registry.default_model_id
+
+
+def get_provider_from_model_id(model_id: str) -> Optional[str]:
+ """Get the provider name for a given model_id"""
+ registry = ModelRegistry.get_instance()
+ model = registry.get_model(model_id)
+ if model:
+ return model.provider.value
+ return None
+
+
+def get_token_limit(model_id: str) -> int:
+ """
+ Get context window (token limit) for a model.
+ Returns the model's context_window, or settings.DEFAULT_LLM_TOKEN_LIMIT if the model is not found.
+ """
+ from application.core.settings import settings
+
+ registry = ModelRegistry.get_instance()
+ model = registry.get_model(model_id)
+ if model:
+ return model.capabilities.context_window
+ return settings.DEFAULT_LLM_TOKEN_LIMIT
+
+
+def get_base_url_for_model(model_id: str) -> Optional[str]:
+ """
+ Get the custom base_url for a specific model if configured.
+ Returns None if the model is not in the registry or no custom base_url is set.
+ """
+ registry = ModelRegistry.get_instance()
+ model = registry.get_model(model_id)
+ if model:
+ return model.base_url
+ return None
diff --git a/application/core/settings.py b/application/core/settings.py
index 22116a7c..ee7ffa05 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -22,15 +22,7 @@ class Settings(BaseSettings):
MONGO_DB_NAME: str = "docsgpt"
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
- LLM_TOKEN_LIMITS: dict = {
- "gpt-4o": 128000,
- "gpt-4o-mini": 128000,
- "gpt-4": 8192,
- "gpt-3.5-turbo": 4096,
- "claude-2": int(1e5),
- "gemini-2.5-flash": int(1e6),
- }
- DEFAULT_LLM_TOKEN_LIMIT: int = 128000
+ DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry
RESERVED_TOKENS: dict = {
"system_prompt": 500,
"current_query": 500,
@@ -64,14 +56,22 @@ class Settings(BaseSettings):
)
# GitHub source
- GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
+ GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
# LLM Cache
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
API_URL: str = "http://localhost:7091" # backend url for celery worker
- API_KEY: Optional[str] = None # LLM api key
+ API_KEY: Optional[str] = None # LLM api key (used by LLM_PROVIDER)
+
+ # Provider-specific API keys (for multi-model support)
+ OPENAI_API_KEY: Optional[str] = None
+ ANTHROPIC_API_KEY: Optional[str] = None
+ GOOGLE_API_KEY: Optional[str] = None
+ GROQ_API_KEY: Optional[str] = None
+ HUGGINGFACE_API_KEY: Optional[str] = None
+
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
@@ -138,11 +138,12 @@ class Settings(BaseSettings):
# Encryption settings
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
- TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
+ TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
ELEVENLABS_API_KEY: Optional[str] = None
# Tool pre-fetch settings
ENABLE_TOOL_PREFETCH: bool = True
+
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py
index b55dd855..4d26f925 100644
--- a/application/llm/anthropic.py
+++ b/application/llm/anthropic.py
@@ -1,30 +1,41 @@
-from application.llm.base import BaseLLM
+from anthropic import AI_PROMPT, Anthropic, HUMAN_PROMPT
+
from application.core.settings import settings
+from application.llm.base import BaseLLM
class AnthropicLLM(BaseLLM):
- def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
- from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
+ def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.api_key = (
- api_key or settings.ANTHROPIC_API_KEY
- ) # If not provided, use a default from settings
+ self.api_key = api_key or settings.ANTHROPIC_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
- self.anthropic = Anthropic(api_key=self.api_key)
+
+ # Use custom base_url if provided
+ if base_url:
+ self.anthropic = Anthropic(api_key=self.api_key, base_url=base_url)
+ else:
+ self.anthropic = Anthropic(api_key=self.api_key)
+
self.HUMAN_PROMPT = HUMAN_PROMPT
self.AI_PROMPT = AI_PROMPT
def _raw_gen(
- self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs
+ self,
+ baseself,
+ model,
+ messages,
+ stream=False,
+ tools=None,
+ max_tokens=300,
+ **kwargs,
):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
if stream:
return self.gen_stream(model, prompt, stream, max_tokens, **kwargs)
-
completion = self.anthropic.completions.create(
model=model,
max_tokens_to_sample=max_tokens,
@@ -34,7 +45,14 @@ class AnthropicLLM(BaseLLM):
return completion.completion
def _raw_gen_stream(
- self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs
+ self,
+ baseself,
+ model,
+ messages,
+ stream=True,
+ tools=None,
+ max_tokens=300,
+ **kwargs,
):
context = messages[0]["content"]
user_question = messages[-1]["content"]
@@ -50,5 +68,5 @@ class AnthropicLLM(BaseLLM):
for completion in stream_response:
yield completion.completion
finally:
- if hasattr(stream_response, 'close'):
+ if hasattr(stream_response, "close"):
stream_response.close()
diff --git a/application/llm/base.py b/application/llm/base.py
index c16ec99e..e8ee78eb 100644
--- a/application/llm/base.py
+++ b/application/llm/base.py
@@ -13,30 +13,32 @@ class BaseLLM(ABC):
def __init__(
self,
decoded_token=None,
+ model_id=None,
+ base_url=None,
):
self.decoded_token = decoded_token
+ self.model_id = model_id
+ self.base_url = base_url
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
- self.fallback_provider = settings.FALLBACK_LLM_PROVIDER
- self.fallback_model_name = settings.FALLBACK_LLM_NAME
- self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY
self._fallback_llm = None
+ self._fallback_sequence_index = 0
@property
def fallback_llm(self):
- """Lazy-loaded fallback LLM instance."""
- if (
- self._fallback_llm is None
- and self.fallback_provider
- and self.fallback_model_name
- ):
+ """Lazy-loaded fallback LLM from FALLBACK_* settings."""
+ if self._fallback_llm is None and settings.FALLBACK_LLM_PROVIDER:
try:
from application.llm.llm_creator import LLMCreator
self._fallback_llm = LLMCreator.create_llm(
- self.fallback_provider,
- self.fallback_llm_api_key,
- None,
- self.decoded_token,
+ settings.FALLBACK_LLM_PROVIDER,
+ api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
+ user_api_key=None,
+ decoded_token=self.decoded_token,
+ model_id=settings.FALLBACK_LLM_NAME,
+ )
+ logger.info(
+ f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
)
except Exception as e:
logger.error(
@@ -54,7 +56,7 @@ class BaseLLM(ABC):
self, method_name: str, decorators: list, *args, **kwargs
):
"""
- Unified method execution with fallback support.
+ Execute method with fallback support.
Args:
method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream')
@@ -73,10 +75,10 @@ class BaseLLM(ABC):
return decorated_method()
except Exception as e:
if not self.fallback_llm:
- logger.error(f"Primary LLM failed and no fallback available: {str(e)}")
+ logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
raise
logger.warning(
- f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}"
+ f"Primary LLM failed. Falling back to {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}. Error: {str(e)}"
)
fallback_method = getattr(
diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py
index 3572db40..44a479ae 100644
--- a/application/llm/docsgpt_provider.py
+++ b/application/llm/docsgpt_provider.py
@@ -1,5 +1,7 @@
import json
+from openai import OpenAI
+
from application.core.settings import settings
from application.llm.base import BaseLLM
@@ -7,12 +9,11 @@ from application.llm.base import BaseLLM
class DocsGPTAPILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
- from openai import OpenAI
super().__init__(*args, **kwargs)
- self.client = OpenAI(api_key="sk-docsgpt-public", base_url="https://oai.arc53.com")
+ self.api_key = "sk-docsgpt-public"
+ self.client = OpenAI(api_key=self.api_key, base_url="https://oai.arc53.com")
self.user_api_key = user_api_key
- self.api_key = api_key
def _clean_messages_openai(self, messages):
cleaned_messages = []
@@ -22,7 +23,6 @@ class DocsGPTAPILLM(BaseLLM):
if role == "model":
role = "assistant"
-
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
@@ -69,7 +69,6 @@ class DocsGPTAPILLM(BaseLLM):
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
-
return cleaned_messages
def _raw_gen(
@@ -121,7 +120,6 @@ class DocsGPTAPILLM(BaseLLM):
response = self.client.chat.completions.create(
model="docsgpt", messages=messages, stream=stream, **kwargs
)
-
try:
for line in response:
if (
@@ -133,8 +131,8 @@ class DocsGPTAPILLM(BaseLLM):
elif len(line.choices) > 0:
yield line.choices[0]
finally:
- if hasattr(response, 'close'):
+ if hasattr(response, "close"):
response.close()
def _supports_tools(self):
- return True
\ No newline at end of file
+ return True
diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py
index 47be51cd..9c58a3e1 100644
--- a/application/llm/google_ai.py
+++ b/application/llm/google_ai.py
@@ -13,8 +13,9 @@ from application.storage.storage_creator import StorageCreator
class GoogleLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.api_key = api_key
+ self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
+
self.client = genai.Client(api_key=self.api_key)
self.storage = StorageCreator.get_storage()
@@ -47,21 +48,19 @@ class GoogleLLM(BaseLLM):
"""
if not attachments:
return messages
-
prepared_messages = messages.copy()
# Find the user message to attach files to the last one
+
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
-
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
-
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
@@ -69,7 +68,6 @@ class GoogleLLM(BaseLLM):
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
-
files = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
@@ -92,11 +90,9 @@ class GoogleLLM(BaseLLM):
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
}
)
-
if files:
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
prepared_messages[user_message_index]["content"].append({"files": files})
-
return prepared_messages
def _upload_file_to_google(self, attachment):
@@ -111,14 +107,11 @@ class GoogleLLM(BaseLLM):
"""
if "google_file_uri" in attachment:
return attachment["google_file_uri"]
-
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
-
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
-
try:
file_uri = self.storage.process_file(
file_path,
@@ -136,7 +129,6 @@ class GoogleLLM(BaseLLM):
attachments_collection.update_one(
{"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
)
-
return file_uri
except Exception as e:
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
@@ -153,7 +145,6 @@ class GoogleLLM(BaseLLM):
role = "model"
elif role == "tool":
role = "model"
-
parts = []
if role and content is not None:
if isinstance(content, str):
@@ -164,6 +155,7 @@ class GoogleLLM(BaseLLM):
parts.append(types.Part.from_text(text=item["text"]))
elif "function_call" in item:
# Remove null values from args to avoid API errors
+
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
@@ -194,10 +186,8 @@ class GoogleLLM(BaseLLM):
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
-
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
-
return cleaned_messages
def _clean_schema(self, schema_obj):
@@ -233,8 +223,8 @@ class GoogleLLM(BaseLLM):
cleaned[key] = [self._clean_schema(item) for item in value]
else:
cleaned[key] = value
-
# Validate that required properties actually exist in properties
+
if "required" in cleaned and "properties" in cleaned:
valid_required = []
properties_keys = set(cleaned["properties"].keys())
@@ -247,7 +237,6 @@ class GoogleLLM(BaseLLM):
cleaned.pop("required", None)
elif "required" in cleaned and "properties" not in cleaned:
cleaned.pop("required", None)
-
return cleaned
def _clean_tools_format(self, tools_list):
@@ -263,7 +252,6 @@ class GoogleLLM(BaseLLM):
cleaned_properties = {}
for k, v in properties.items():
cleaned_properties[k] = self._clean_schema(v)
-
genai_function = dict(
name=function["name"],
description=function["description"],
@@ -282,10 +270,8 @@ class GoogleLLM(BaseLLM):
name=function["name"],
description=function["description"],
)
-
genai_tool = types.Tool(function_declarations=[genai_function])
genai_tools.append(genai_tool)
-
return genai_tools
def _raw_gen(
@@ -307,16 +293,14 @@ class GoogleLLM(BaseLLM):
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
-
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
-
# Add response schema for structured output if provided
+
if response_schema:
config.response_schema = response_schema
config.response_mime_type = "application/json"
-
response = client.models.generate_content(
model=model,
contents=messages,
@@ -347,17 +331,16 @@ class GoogleLLM(BaseLLM):
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
-
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
-
# Add response schema for structured output if provided
+
if response_schema:
config.response_schema = response_schema
config.response_mime_type = "application/json"
-
# Check if we have both tools and file attachments
+
has_attachments = False
for message in messages:
for part in message.parts:
@@ -366,7 +349,6 @@ class GoogleLLM(BaseLLM):
break
if has_attachments:
break
-
logging.info(
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
)
@@ -405,7 +387,6 @@ class GoogleLLM(BaseLLM):
"""Convert JSON schema to Google AI structured output format."""
if not json_schema:
return None
-
type_map = {
"object": "OBJECT",
"array": "ARRAY",
@@ -418,12 +399,10 @@ class GoogleLLM(BaseLLM):
def convert(schema):
if not isinstance(schema, dict):
return schema
-
result = {}
schema_type = schema.get("type")
if schema_type:
result["type"] = type_map.get(schema_type.lower(), schema_type.upper())
-
for key in [
"description",
"nullable",
@@ -435,7 +414,6 @@ class GoogleLLM(BaseLLM):
]:
if key in schema:
result[key] = schema[key]
-
if "format" in schema:
format_value = schema["format"]
if schema_type == "string":
@@ -445,21 +423,17 @@ class GoogleLLM(BaseLLM):
result["format"] = format_value
else:
result["format"] = format_value
-
if "properties" in schema:
result["properties"] = {
k: convert(v) for k, v in schema["properties"].items()
}
if "propertyOrdering" not in result and result.get("type") == "OBJECT":
result["propertyOrdering"] = list(result["properties"].keys())
-
if "items" in schema:
result["items"] = convert(schema["items"])
-
for field in ["anyOf", "oneOf", "allOf"]:
if field in schema:
result[field] = [convert(s) for s in schema[field]]
-
return result
try:
diff --git a/application/llm/groq.py b/application/llm/groq.py
index 282d7f47..c2ae40ee 100644
--- a/application/llm/groq.py
+++ b/application/llm/groq.py
@@ -1,13 +1,18 @@
-from application.llm.base import BaseLLM
from openai import OpenAI
+from application.core.settings import settings
+from application.llm.base import BaseLLM
+
class GroqLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
+
super().__init__(*args, **kwargs)
- self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
- self.api_key = api_key
+ self.api_key = api_key or settings.GROQ_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
+ self.client = OpenAI(
+ api_key=self.api_key, base_url="https://api.groq.com/openai/v1"
+ )
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
if tools:
diff --git a/application/llm/handlers/base.py b/application/llm/handlers/base.py
index 96ed4c00..920caf65 100644
--- a/application/llm/handlers/base.py
+++ b/application/llm/handlers/base.py
@@ -282,7 +282,7 @@ class LLMHandler(ABC):
messages = e.value
break
response = agent.llm.gen(
- model=agent.gpt_model, messages=messages, tools=agent.tools
+ model=agent.model_id, messages=messages, tools=agent.tools
)
parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm))
@@ -337,7 +337,7 @@ class LLMHandler(ABC):
tool_calls = {}
response = agent.llm.gen_stream(
- model=agent.gpt_model, messages=messages, tools=agent.tools
+ model=agent.model_id, messages=messages, tools=agent.tools
)
self.llm_calls.append(build_stack_data(agent.llm))
diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py
index 3ed23854..21d653b9 100644
--- a/application/llm/llm_creator.py
+++ b/application/llm/llm_creator.py
@@ -1,13 +1,17 @@
-from application.llm.groq import GroqLLM
-from application.llm.openai import OpenAILLM, AzureOpenAILLM
-from application.llm.sagemaker import SagemakerAPILLM
-from application.llm.huggingface import HuggingFaceLLM
-from application.llm.llama_cpp import LlamaCpp
+import logging
+
from application.llm.anthropic import AnthropicLLM
from application.llm.docsgpt_provider import DocsGPTAPILLM
-from application.llm.premai import PremAILLM
from application.llm.google_ai import GoogleLLM
+from application.llm.groq import GroqLLM
+from application.llm.huggingface import HuggingFaceLLM
+from application.llm.llama_cpp import LlamaCpp
from application.llm.novita import NovitaLLM
+from application.llm.openai import AzureOpenAILLM, OpenAILLM
+from application.llm.premai import PremAILLM
+from application.llm.sagemaker import SagemakerAPILLM
+
+logger = logging.getLogger(__name__)
class LLMCreator:
@@ -26,10 +30,26 @@ class LLMCreator:
}
@classmethod
- def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
+ def create_llm(
+ cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
+ ):
+ from application.core.model_utils import get_base_url_for_model
+
llm_class = cls.llms.get(type.lower())
if not llm_class:
raise ValueError(f"No LLM class found for type {type}")
+
+ # Extract base_url from model configuration if model_id is provided
+ base_url = None
+ if model_id:
+ base_url = get_base_url_for_model(model_id)
+
return llm_class(
- api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
+ api_key,
+ user_api_key,
+ decoded_token=decoded_token,
+ model_id=model_id,
+ base_url=base_url,
+ *args,
+ **kwargs,
)
diff --git a/application/llm/openai.py b/application/llm/openai.py
index de24e2c5..beab465b 100644
--- a/application/llm/openai.py
+++ b/application/llm/openai.py
@@ -2,6 +2,8 @@ import base64
import json
import logging
+from openai import OpenAI
+
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
@@ -9,20 +11,25 @@ from application.storage.storage_creator import StorageCreator
class OpenAILLM(BaseLLM):
- def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
- from openai import OpenAI
+ def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(*args, **kwargs)
- if (
+ self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
+ self.user_api_key = user_api_key
+
+ # Priority: 1) Parameter base_url, 2) Settings OPENAI_BASE_URL, 3) Default
+ effective_base_url = None
+ if base_url and isinstance(base_url, str) and base_url.strip():
+ effective_base_url = base_url
+ elif (
isinstance(settings.OPENAI_BASE_URL, str)
and settings.OPENAI_BASE_URL.strip()
):
- self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
+ effective_base_url = settings.OPENAI_BASE_URL
else:
- DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
- self.client = OpenAI(api_key=api_key, base_url=DEFAULT_OPENAI_API_BASE)
- self.api_key = api_key
- self.user_api_key = user_api_key
+ effective_base_url = "https://api.openai.com/v1"
+
+ self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
self.storage = StorageCreator.get_storage()
def _clean_messages_openai(self, messages):
@@ -33,7 +40,6 @@ class OpenAILLM(BaseLLM):
if role == "model":
role = "assistant"
-
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
@@ -107,7 +113,6 @@ class OpenAILLM(BaseLLM):
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
-
return cleaned_messages
def _raw_gen(
@@ -132,10 +137,8 @@ class OpenAILLM(BaseLLM):
if tools:
request_params["tools"] = tools
-
if response_format:
request_params["response_format"] = response_format
-
response = self.client.chat.completions.create(**request_params)
if tools:
@@ -165,10 +168,8 @@ class OpenAILLM(BaseLLM):
if tools:
request_params["tools"] = tools
-
if response_format:
request_params["response_format"] = response_format
-
response = self.client.chat.completions.create(**request_params)
try:
@@ -194,7 +195,6 @@ class OpenAILLM(BaseLLM):
def prepare_structured_output_format(self, json_schema):
if not json_schema:
return None
-
try:
def add_additional_properties_false(schema_obj):
@@ -204,11 +204,11 @@ class OpenAILLM(BaseLLM):
if schema_copy.get("type") == "object":
schema_copy["additionalProperties"] = False
# Ensure 'required' includes all properties for OpenAI strict mode
+
if "properties" in schema_copy:
schema_copy["required"] = list(
schema_copy["properties"].keys()
)
-
for key, value in schema_copy.items():
if key == "properties" and isinstance(value, dict):
schema_copy[key] = {
@@ -224,7 +224,6 @@ class OpenAILLM(BaseLLM):
add_additional_properties_false(sub_schema)
for sub_schema in value
]
-
return schema_copy
return schema_obj
@@ -243,7 +242,6 @@ class OpenAILLM(BaseLLM):
}
return result
-
except Exception as e:
logging.error(f"Error preparing structured output format: {e}")
return None
@@ -277,21 +275,19 @@ class OpenAILLM(BaseLLM):
"""
if not attachments:
return messages
-
prepared_messages = messages.copy()
# Find the user message to attach file_id to the last one
+
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
-
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
-
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
@@ -299,7 +295,6 @@ class OpenAILLM(BaseLLM):
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
-
for attachment in attachments:
mime_type = attachment.get("mime_type")
@@ -326,6 +321,7 @@ class OpenAILLM(BaseLLM):
}
)
# Handle PDFs using the file API
+
elif mime_type == "application/pdf":
try:
file_id = self._upload_file_to_openai(attachment)
@@ -341,7 +337,6 @@ class OpenAILLM(BaseLLM):
"text": f"File content:\n\n{attachment['content']}",
}
)
-
return prepared_messages
def _get_base64_image(self, attachment):
@@ -357,7 +352,6 @@ class OpenAILLM(BaseLLM):
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
-
try:
with self.storage.get_file(file_path) as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
@@ -381,12 +375,10 @@ class OpenAILLM(BaseLLM):
if "openai_file_id" in attachment:
return attachment["openai_file_id"]
-
file_path = attachment.get("path")
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
-
try:
file_id = self.storage.process_file(
file_path,
@@ -404,7 +396,6 @@ class OpenAILLM(BaseLLM):
attachments_collection.update_one(
{"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
)
-
return file_id
except Exception as e:
logging.error(f"Error uploading file to OpenAI: {e}", exc_info=True)
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index f0428a26..f0468a18 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -16,7 +16,7 @@ class ClassicRAG(BaseRetriever):
prompt="",
chunks=2,
doc_token_limit=50000,
- gpt_model="docsgpt",
+ model_id="docsgpt-local",
user_api_key=None,
llm_name=settings.LLM_PROVIDER,
api_key=settings.API_KEY,
@@ -40,7 +40,7 @@ class ClassicRAG(BaseRetriever):
f"ClassicRAG initialized with chunks={self.chunks}, user_api_key={user_identifier}, "
f"sources={'active_docs' in source and source['active_docs'] is not None}"
)
- self.gpt_model = gpt_model
+ self.model_id = model_id
self.doc_token_limit = doc_token_limit
self.user_api_key = user_api_key
self.llm_name = llm_name
@@ -100,7 +100,7 @@ class ClassicRAG(BaseRetriever):
]
try:
- rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages)
+ rephrased_query = self.llm.gen(model=self.model_id, messages=messages)
print(f"Rephrased query: {rephrased_query}")
return rephrased_query if rephrased_query else self.original_question
except Exception as e:
diff --git a/application/utils.py b/application/utils.py
index 06eaf495..89b884f0 100644
--- a/application/utils.py
+++ b/application/utils.py
@@ -7,6 +7,8 @@ import tiktoken
from flask import jsonify, make_response
from werkzeug.utils import secure_filename
+from application.core.model_utils import get_token_limit
+
from application.core.settings import settings
@@ -75,11 +77,9 @@ def count_tokens_docs(docs):
def calculate_doc_token_budget(
- gpt_model: str = "gpt-4o", history_token_limit: int = 2000
+ model_id: str = "gpt-4o", history_token_limit: int = 2000
) -> int:
- total_context = settings.LLM_TOKEN_LIMITS.get(
- gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT
- )
+ total_context = get_token_limit(model_id)
reserved = sum(settings.RESERVED_TOKENS.values())
doc_budget = total_context - history_token_limit - reserved
return max(doc_budget, 1000)
@@ -144,16 +144,13 @@ def get_hash(data):
return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()
-def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
+def limit_chat_history(history, max_token_limit=None, model_id="docsgpt-local"):
"""Limit chat history to fit within token limit."""
- from application.core.settings import settings
-
+ model_token_limit = get_token_limit(model_id)
max_token_limit = (
max_token_limit
- if max_token_limit
- and max_token_limit
- < settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT)
- else settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT)
+ if max_token_limit and max_token_limit < model_token_limit
+ else model_token_limit
)
if not history:
@@ -205,37 +202,44 @@ def clean_text_for_tts(text: str) -> str:
clean text for Text-to-Speech processing.
"""
# Handle code blocks and links
- text = re.sub(r'```mermaid[\s\S]*?```', ' flowchart, ', text) ## ```mermaid...```
- text = re.sub(r'```[\s\S]*?```', ' code block, ', text) ## ```code```
- text = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', text) ## [text](url)
- text = re.sub(r'!\[([^\]]*)\]\([^\)]+\)', '', text) ## 
+
+ text = re.sub(r"```mermaid[\s\S]*?```", " flowchart, ", text) ## ```mermaid...```
+ text = re.sub(r"```[\s\S]*?```", " code block, ", text) ## ```code```
+ text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text) ## [text](url)
+ text = re.sub(r"!\[([^\]]*)\]\([^\)]+\)", "", text) ## ![alt](url)
# Remove markdown formatting
- text = re.sub(r'`([^`]+)`', r'\1', text) ## `code`
- text = re.sub(r'\{([^}]*)\}', r' \1 ', text) ## {text}
- text = re.sub(r'[{}]', ' ', text) ## unmatched {}
- text = re.sub(r'\[([^\]]+)\]', r' \1 ', text) ## [text]
- text = re.sub(r'[\[\]]', ' ', text) ## unmatched []
- text = re.sub(r'(\*\*|__)(.*?)\1', r'\2', text) ## **bold** __bold__
- text = re.sub(r'(\*|_)(.*?)\1', r'\2', text) ## *italic* _italic_
- text = re.sub(r'^#{1,6}\s+', '', text, flags=re.MULTILINE) ## # headers
- text = re.sub(r'^>\s+', '', text, flags=re.MULTILINE) ## > blockquotes
- text = re.sub(r'^[\s]*[-\*\+]\s+', '', text, flags=re.MULTILINE) ## - * + lists
- text = re.sub(r'^[\s]*\d+\.\s+', '', text, flags=re.MULTILINE) ## 1. numbered lists
- text = re.sub(r'^[\*\-_]{3,}\s*$', '', text, flags=re.MULTILINE) ## --- *** ___ rules
- text = re.sub(r'<[^>]*>', '', text) ## tags
- #Remove non-ASCII (emojis, special Unicode)
- text = re.sub(r'[^\x20-\x7E\n\r\t]', '', text)
+ text = re.sub(r"`([^`]+)`", r"\1", text) ## `code`
+ text = re.sub(r"\{([^}]*)\}", r" \1 ", text) ## {text}
+ text = re.sub(r"[{}]", " ", text) ## unmatched {}
+ text = re.sub(r"\[([^\]]+)\]", r" \1 ", text) ## [text]
+ text = re.sub(r"[\[\]]", " ", text) ## unmatched []
+ text = re.sub(r"(\*\*|__)(.*?)\1", r"\2", text) ## **bold** __bold__
+ text = re.sub(r"(\*|_)(.*?)\1", r"\2", text) ## *italic* _italic_
+ text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE) ## # headers
+ text = re.sub(r"^>\s+", "", text, flags=re.MULTILINE) ## > blockquotes
+ text = re.sub(r"^[\s]*[-\*\+]\s+", "", text, flags=re.MULTILINE) ## - * + lists
+ text = re.sub(r"^[\s]*\d+\.\s+", "", text, flags=re.MULTILINE) ## 1. numbered lists
+ text = re.sub(
+ r"^[\*\-_]{3,}\s*$", "", text, flags=re.MULTILINE
+ ) ## --- *** ___ rules
+ text = re.sub(r"<[^>]*>", "", text) ## tags
- #Replace special sequences
- text = re.sub(r'-->', ', ', text) ## -->
- text = re.sub(r'<--', ', ', text) ## <--
- text = re.sub(r'=>', ', ', text) ## =>
- text = re.sub(r'::', ' ', text) ## ::
+ # Remove non-ASCII (emojis, special Unicode)
- #Normalize whitespace
- text = re.sub(r'\s+', ' ', text)
+ text = re.sub(r"[^\x20-\x7E\n\r\t]", "", text)
+
+ # Replace special sequences
+
+ text = re.sub(r"-->", ", ", text) ## -->
+ text = re.sub(r"<--", ", ", text) ## <--
+ text = re.sub(r"=>", ", ", text) ## =>
+ text = re.sub(r"::", " ", text) ## ::
+
+ # Normalize whitespace
+
+ text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
diff --git a/application/worker.py b/application/worker.py
index f17e1537..3f957527 100755
--- a/application/worker.py
+++ b/application/worker.py
@@ -165,7 +165,7 @@ def run_agent_logic(agent_config, input_data):
agent_type,
endpoint="webhook",
llm_name=settings.LLM_PROVIDER,
- gpt_model=settings.LLM_NAME,
+ model_id=settings.LLM_NAME,
api_key=settings.API_KEY,
user_api_key=user_api_key,
prompt=prompt,
@@ -180,7 +180,7 @@ def run_agent_logic(agent_config, input_data):
prompt=prompt,
chunks=chunks,
token_limit=settings.DEFAULT_MAX_HISTORY,
- gpt_model=settings.LLM_NAME,
+ model_id=settings.LLM_NAME,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
diff --git a/frontend/src/Hero.tsx b/frontend/src/Hero.tsx
index 9b17c10f..8b08430b 100644
--- a/frontend/src/Hero.tsx
+++ b/frontend/src/Hero.tsx
@@ -1,6 +1,8 @@
-import DocsGPT3 from './assets/cute_docsgpt3.svg';
import { useTranslation } from 'react-i18next';
+import DocsGPT3 from './assets/cute_docsgpt3.svg';
+import DropdownModel from './components/DropdownModel';
+
export default function Hero({
handleQuestion,
}: {
@@ -26,6 +28,10 @@ export default function Hero({
DocsGPT
+ {/* Model Selector */}
+