diff --git a/application/agents/agent_creator.py b/application/agents/agent_creator.py
index bf37d4ec..44e89552 100644
--- a/application/agents/agent_creator.py
+++ b/application/agents/agent_creator.py
@@ -1,5 +1,8 @@
 from application.agents.classic_agent import ClassicAgent
 from application.agents.react_agent import ReActAgent
+import logging
+
+logger = logging.getLogger(__name__)


 class AgentCreator:
@@ -13,4 +16,5 @@ class AgentCreator:
         agent_class = cls.agents.get(type.lower())
         if not agent_class:
             raise ValueError(f"No agent class found for type {type}")
+
         return agent_class(*args, **kwargs)
diff --git a/application/agents/base.py b/application/agents/base.py
index 27428fc3..dbf15a1f 100644
--- a/application/agents/base.py
+++ b/application/agents/base.py
@@ -21,7 +21,7 @@ class BaseAgent(ABC):
         self,
         endpoint: str,
         llm_name: str,
-        gpt_model: str,
+        model_id: str,
         api_key: str,
         user_api_key: Optional[str] = None,
         prompt: str = "",
@@ -37,7 +37,7 @@ class BaseAgent(ABC):
     ):
         self.endpoint = endpoint
         self.llm_name = llm_name
-        self.gpt_model = gpt_model
+        self.model_id = model_id
         self.api_key = api_key
         self.user_api_key = user_api_key
         self.prompt = prompt
@@ -52,6 +52,7 @@ class BaseAgent(ABC):
             api_key=api_key,
             user_api_key=user_api_key,
             decoded_token=decoded_token,
+            model_id=model_id,
         )
         self.retrieved_docs = retrieved_docs or []
         self.llm_handler = LLMHandlerCreator.create_handler(
@@ -316,7 +317,7 @@ class BaseAgent(ABC):
         return messages

     def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
-        gen_kwargs = {"model": self.gpt_model, "messages": messages}
+        gen_kwargs = {"model": self.model_id, "messages": messages}

         if (
             hasattr(self.llm, "_supports_tools")
diff --git a/application/agents/react_agent.py b/application/agents/react_agent.py
index 49dd29d8..116fa4aa 100644
--- a/application/agents/react_agent.py
+++ b/application/agents/react_agent.py
@@ -86,7 +86,7 @@ class ReActAgent(BaseAgent):
         messages = [{"role": "user", "content": plan_prompt}]

         plan_stream = self.llm.gen_stream(
-            model=self.gpt_model,
+            model=self.model_id,
             messages=messages,
             tools=self.tools if self.tools else None,
         )
@@ -151,7 +151,7 @@ class ReActAgent(BaseAgent):
         messages = [{"role": "user", "content": final_prompt}]

         final_stream = self.llm.gen_stream(
-            model=self.gpt_model, messages=messages, tools=None
+            model=self.model_id, messages=messages, tools=None
         )

         if log_context:
diff --git a/application/api/answer/routes/answer.py b/application/api/answer/routes/answer.py
index 87d80059..bc7ec58c 100644
--- a/application/api/answer/routes/answer.py
+++ b/application/api/answer/routes/answer.py
@@ -54,6 +54,10 @@ class AnswerResource(Resource, BaseAnswerResource):
                 default=True,
                 description="Whether to save the conversation",
             ),
+            "model_id": fields.String(
+                required=False,
+                description="Model ID to use for this request",
+            ),
             "passthrough": fields.Raw(
                 required=False,
                 description="Dynamic parameters to inject into prompt template",
@@ -97,6 +101,7 @@ class AnswerResource(Resource, BaseAnswerResource):
             isNoneDoc=data.get("isNoneDoc"),
             index=None,
             should_save_conversation=data.get("save_conversation", True),
+            model_id=processor.model_id,
         )
         stream_result = self.process_response_stream(stream)
diff --git a/application/api/answer/routes/base.py b/application/api/answer/routes/base.py
index 43e83ed2..aefb66c6 100644
--- a/application/api/answer/routes/base.py
+++ b/application/api/answer/routes/base.py
@@ -7,11 +7,16 @@
 from flask import jsonify, make_response, Response
 from flask_restx import Namespace

 from application.api.answer.services.conversation_service import ConversationService
+from application.core.model_utils import (
+    get_api_key_for_provider,
+    get_default_model_id,
+    get_provider_from_model_id,
+)
 from application.core.mongo_db import MongoDB
 from application.core.settings import settings
 from application.llm.llm_creator import LLMCreator
-from application.utils import check_required_fields, get_gpt_model
+from application.utils import check_required_fields

 logger = logging.getLogger(__name__)
@@ -27,7 +32,7 @@ class BaseAnswerResource:
         db = mongo[settings.MONGO_DB_NAME]
         self.db = db
         self.user_logs_collection = db["user_logs"]
-        self.gpt_model = get_gpt_model()
+        self.default_model_id = get_default_model_id()
         self.conversation_service = ConversationService()

     def validate_request(
@@ -54,7 +59,6 @@ class BaseAnswerResource:
         api_key = agent_config.get("user_api_key")
         if not api_key:
             return None
-
         agents_collection = self.db["agents"]
         agent = agents_collection.find_one({"key": api_key})
@@ -62,7 +66,6 @@ class BaseAnswerResource:
             return make_response(
                 jsonify({"success": False, "message": "Invalid API key."}), 401
             )
-
         limited_token_mode_raw = agent.get("limited_token_mode", False)
         limited_request_mode_raw = agent.get("limited_request_mode", False)
@@ -110,15 +113,12 @@ class BaseAnswerResource:
             daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
         else:
             daily_token_usage = 0
-
         if limited_request_mode:
             daily_request_usage = token_usage_collection.count_documents(match_query)
         else:
             daily_request_usage = 0
-
         if not limited_token_mode and not limited_request_mode:
             return None
-
         token_exceeded = (
             limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
         )
@@ -138,7 +138,6 @@ class BaseAnswerResource:
                 ),
                 429,
             )
-
         return None

     def complete_stream(
@@ -155,6 +154,7 @@ class BaseAnswerResource:
         agent_id: Optional[str] = None,
         is_shared_usage: bool = False,
         shared_token: Optional[str] = None,
+        model_id: Optional[str] = None,
     ) -> Generator[str, None, None]:
         """
         Generator function that streams the complete conversation response.
@@ -173,6 +173,7 @@ class BaseAnswerResource:
             agent_id: ID of agent used
             is_shared_usage: Flag for shared agent usage
             shared_token: Token for shared agent
+            model_id: Model ID used for the request
             retrieved_docs: Pre-fetched documents for sources (optional)

         Yields:
@@ -220,7 +221,6 @@ class BaseAnswerResource:
                 elif "type" in line:
                     data = json.dumps(line)
                     yield f"data: {data}\n\n"
-
             if is_structured and structured_chunks:
                 structured_data = {
                     "type": "structured_answer",
@@ -230,15 +230,22 @@ class BaseAnswerResource:
                 }
                 data = json.dumps(structured_data)
                 yield f"data: {data}\n\n"
-
             if isNoneDoc:
                 for doc in source_log_docs:
                     doc["source"] = "None"

+            provider = (
+                get_provider_from_model_id(model_id)
+                if model_id
+                else settings.LLM_PROVIDER
+            )
+            system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
+
             llm = LLMCreator.create_llm(
-                settings.LLM_PROVIDER,
-                api_key=settings.API_KEY,
+                provider or settings.LLM_PROVIDER,
+                api_key=system_api_key,
                 user_api_key=user_api_key,
                 decoded_token=decoded_token,
+                model_id=model_id,
             )

             if should_save_conversation:
@@ -250,7 +257,7 @@ class BaseAnswerResource:
                     source_log_docs,
                     tool_calls,
                     llm,
-                    self.gpt_model,
+                    model_id or self.default_model_id,
                     decoded_token,
                     index=index,
                     api_key=user_api_key,
@@ -280,12 +287,11 @@ class BaseAnswerResource:
                 log_data["structured_output"] = True
                 if schema_info:
                     log_data["schema"] = schema_info
-            # Clean up text fields to be no longer than 10000 characters
+
             for key, value in log_data.items():
                 if isinstance(value, str) and len(value) > 10000:
                     log_data[key] = value[:10000]
-
             self.user_logs_collection.insert_one(log_data)

             data = json.dumps({"type": "end"})
@@ -293,6 +299,7 @@ class BaseAnswerResource:
         except GeneratorExit:
") # Save partial response + if should_save_conversation and response_full: try: if isNoneDoc: @@ -312,7 +319,7 @@ class BaseAnswerResource: source_log_docs, tool_calls, llm, - self.gpt_model, + model_id or self.default_model_id, decoded_token, index=index, api_key=user_api_key, @@ -369,7 +376,7 @@ class BaseAnswerResource: thought = event["thought"] elif event["type"] == "error": logger.error(f"Error from stream: {event['error']}") - return None, None, None, None, event["error"] + return None, None, None, None, event["error"], None elif event["type"] == "end": stream_ended = True except (json.JSONDecodeError, KeyError) as e: @@ -377,8 +384,7 @@ class BaseAnswerResource: continue if not stream_ended: logger.error("Stream ended unexpectedly without an 'end' event.") - return None, None, None, None, "Stream ended unexpectedly" - + return None, None, None, None, "Stream ended unexpectedly", None result = ( conversation_id, response_full, @@ -390,7 +396,6 @@ class BaseAnswerResource: if is_structured: result = result + ({"structured": True, "schema": schema_info},) - return result def error_stream_generate(self, err_response): diff --git a/application/api/answer/routes/stream.py b/application/api/answer/routes/stream.py index 92e41c14..b2827a93 100644 --- a/application/api/answer/routes/stream.py +++ b/application/api/answer/routes/stream.py @@ -57,6 +57,10 @@ class StreamResource(Resource, BaseAnswerResource): default=True, description="Whether to save the conversation", ), + "model_id": fields.String( + required=False, + description="Model ID to use for this request", + ), "attachments": fields.List( fields.String, required=False, description="List of attachment IDs" ), @@ -101,6 +105,7 @@ class StreamResource(Resource, BaseAnswerResource): agent_id=data.get("agent_id"), is_shared_usage=processor.is_shared_usage, shared_token=processor.shared_token, + model_id=processor.model_id, ), mimetype="text/event-stream", ) diff --git a/application/api/answer/services/conversation_service.py b/application/api/answer/services/conversation_service.py index eca842d6..0e98983e 100644 --- a/application/api/answer/services/conversation_service.py +++ b/application/api/answer/services/conversation_service.py @@ -52,7 +52,7 @@ class ConversationService: sources: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]], llm: Any, - gpt_model: str, + model_id: str, decoded_token: Dict[str, Any], index: Optional[int] = None, api_key: Optional[str] = None, @@ -66,7 +66,7 @@ class ConversationService: if not user_id: raise ValueError("User ID not found in token") current_time = datetime.now(timezone.utc) - + # clean up in sources array such that we save max 1k characters for text part for source in sources: if "text" in source and isinstance(source["text"], str): @@ -90,6 +90,7 @@ class ConversationService: f"queries.{index}.tool_calls": tool_calls, f"queries.{index}.timestamp": current_time, f"queries.{index}.attachments": attachment_ids, + f"queries.{index}.model_id": model_id, } }, ) @@ -120,6 +121,7 @@ class ConversationService: "tool_calls": tool_calls, "timestamp": current_time, "attachments": attachment_ids, + "model_id": model_id, } } }, @@ -146,7 +148,7 @@ class ConversationService: ] completion = llm.gen( - model=gpt_model, messages=messages_summary, max_tokens=30 + model=model_id, messages=messages_summary, max_tokens=30 ) conversation_data = { @@ -162,6 +164,7 @@ class ConversationService: "tool_calls": tool_calls, "timestamp": current_time, "attachments": attachment_ids, + "model_id": model_id, } ], } 
diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py
index bb890937..586e7696 100644
--- a/application/api/answer/services/stream_processor.py
+++ b/application/api/answer/services/stream_processor.py
@@ -12,12 +12,17 @@
 from bson.objectid import ObjectId

 from application.agents.agent_creator import AgentCreator
 from application.api.answer.services.conversation_service import ConversationService
 from application.api.answer.services.prompt_renderer import PromptRenderer
+from application.core.model_utils import (
+    get_api_key_for_provider,
+    get_default_model_id,
+    get_provider_from_model_id,
+    validate_model_id,
+)
 from application.core.mongo_db import MongoDB
 from application.core.settings import settings
 from application.retriever.retriever_creator import RetrieverCreator
 from application.utils import (
     calculate_doc_token_budget,
-    get_gpt_model,
     limit_chat_history,
 )
@@ -83,7 +88,7 @@ class StreamProcessor:
         self.retriever_config = {}
         self.is_shared_usage = False
         self.shared_token = None
-        self.gpt_model = get_gpt_model()
+        self.model_id: Optional[str] = None
         self.conversation_service = ConversationService()
         self.prompt_renderer = PromptRenderer()
         self._prompt_content: Optional[str] = None
@@ -91,6 +96,7 @@ class StreamProcessor:

     def initialize(self):
         """Initialize all required components for processing"""
+        self._validate_and_set_model()
         self._configure_agent()
         self._configure_source()
         self._configure_retriever()
@@ -112,7 +118,7 @@ class StreamProcessor:
             ]
         else:
             self.history = limit_chat_history(
-                json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
+                json.loads(self.data.get("history", "[]")), model_id=self.model_id
             )

     def _process_attachments(self):
@@ -143,6 +149,25 @@ class StreamProcessor:
         )
         return attachments

+    def _validate_and_set_model(self):
+        """Validate and set model_id from request"""
+        from application.core.model_settings import ModelRegistry
+
+        requested_model = self.data.get("model_id")
+
+        if requested_model:
+            if not validate_model_id(requested_model):
+                registry = ModelRegistry.get_instance()
+                available_models = [m.id for m in registry.get_enabled_models()]
+                raise ValueError(
" + f"Available models: {', '.join(available_models[:5])}" + + (f" and {len(available_models) - 5} more" if len(available_models) > 5 else "") + ) + self.model_id = requested_model + else: + self.model_id = get_default_model_id() + def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple: """Get API key for agent with access control""" if not agent_id: @@ -322,7 +347,7 @@ class StreamProcessor: def _configure_retriever(self): history_token_limit = int(self.data.get("token_limit", 2000)) doc_token_limit = calculate_doc_token_budget( - gpt_model=self.gpt_model, history_token_limit=history_token_limit + model_id=self.model_id, history_token_limit=history_token_limit ) self.retriever_config = { @@ -344,7 +369,7 @@ class StreamProcessor: prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection), chunks=self.retriever_config["chunks"], doc_token_limit=self.retriever_config.get("doc_token_limit", 50000), - gpt_model=self.gpt_model, + model_id=self.model_id, user_api_key=self.agent_config["user_api_key"], decoded_token=self.decoded_token, ) @@ -626,12 +651,19 @@ class StreamProcessor: tools_data=tools_data, ) + provider = ( + get_provider_from_model_id(self.model_id) + if self.model_id + else settings.LLM_PROVIDER + ) + system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER) + return AgentCreator.create_agent( self.agent_config["agent_type"], endpoint="stream", - llm_name=settings.LLM_PROVIDER, - gpt_model=self.gpt_model, - api_key=settings.API_KEY, + llm_name=provider or settings.LLM_PROVIDER, + model_id=self.model_id, + api_key=system_api_key, user_api_key=self.agent_config["user_api_key"], prompt=rendered_prompt, chat_history=self.history, diff --git a/application/api/user/agents/routes.py b/application/api/user/agents/routes.py index 6c43342f..bcf68b56 100644 --- a/application/api/user/agents/routes.py +++ b/application/api/user/agents/routes.py @@ -95,6 +95,8 @@ class GetAgent(Resource): "shared": agent.get("shared_publicly", False), "shared_metadata": agent.get("shared_metadata", {}), "shared_token": agent.get("shared_token", ""), + "models": agent.get("models", []), + "default_model_id": agent.get("default_model_id", ""), } return make_response(jsonify(data), 200) except Exception as e: @@ -172,6 +174,8 @@ class GetAgents(Resource): "shared": agent.get("shared_publicly", False), "shared_metadata": agent.get("shared_metadata", {}), "shared_token": agent.get("shared_token", ""), + "models": agent.get("models", []), + "default_model_id": agent.get("default_model_id", ""), } for agent in agents if "source" in agent or "retriever" in agent @@ -230,6 +234,14 @@ class CreateAgent(Resource): required=False, description="Request limit for the agent in limited mode", ), + "models": fields.List( + fields.String, + required=False, + description="List of available model IDs for this agent", + ), + "default_model_id": fields.String( + required=False, description="Default model ID for this agent" + ), }, ) @@ -258,6 +270,11 @@ class CreateAgent(Resource): data["json_schema"] = json.loads(data["json_schema"]) except json.JSONDecodeError: data["json_schema"] = None + if "models" in data: + try: + data["models"] = json.loads(data["models"]) + except json.JSONDecodeError: + data["models"] = [] print(f"Received data: {data}") # Validate JSON schema if provided @@ -399,6 +416,8 @@ class CreateAgent(Resource): "updatedAt": datetime.datetime.now(datetime.timezone.utc), "lastUsedAt": None, "key": key, + "models": data.get("models", []), + 
"default_model_id": data.get("default_model_id", ""), } if new_agent["chunks"] == "": new_agent["chunks"] = "2" @@ -464,6 +483,14 @@ class UpdateAgent(Resource): required=False, description="Request limit for the agent in limited mode", ), + "models": fields.List( + fields.String, + required=False, + description="List of available model IDs for this agent", + ), + "default_model_id": fields.String( + required=False, description="Default model ID for this agent" + ), }, ) @@ -487,7 +514,7 @@ class UpdateAgent(Resource): data = request.get_json() else: data = request.form.to_dict() - json_fields = ["tools", "sources", "json_schema"] + json_fields = ["tools", "sources", "json_schema", "models"] for field in json_fields: if field in data and data[field]: try: @@ -555,6 +582,8 @@ class UpdateAgent(Resource): "token_limit", "limited_request_mode", "request_limit", + "models", + "default_model_id", ] for field in allowed_fields: diff --git a/application/api/user/models/__init__.py b/application/api/user/models/__init__.py new file mode 100644 index 00000000..f32afa11 --- /dev/null +++ b/application/api/user/models/__init__.py @@ -0,0 +1,3 @@ +from .routes import models_ns + +__all__ = ["models_ns"] diff --git a/application/api/user/models/routes.py b/application/api/user/models/routes.py new file mode 100644 index 00000000..886999b2 --- /dev/null +++ b/application/api/user/models/routes.py @@ -0,0 +1,25 @@ +from flask import current_app, jsonify, make_response +from flask_restx import Namespace, Resource + +from application.core.model_settings import ModelRegistry + +models_ns = Namespace("models", description="Available models", path="/api") + + +@models_ns.route("/models") +class ModelsListResource(Resource): + def get(self): + """Get list of available models with their capabilities.""" + try: + registry = ModelRegistry.get_instance() + models = registry.get_enabled_models() + + response = { + "models": [model.to_dict() for model in models], + "default_model_id": registry.default_model_id, + "count": len(models), + } + except Exception as err: + current_app.logger.error(f"Error fetching models: {err}", exc_info=True) + return make_response(jsonify({"success": False}), 500) + return make_response(jsonify(response), 200) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 1e0dbb4e..82e395c5 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -10,6 +10,7 @@ from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns from .analytics import analytics_ns from .attachments import attachments_ns from .conversations import conversations_ns +from .models import models_ns from .prompts import prompts_ns from .sharing import sharing_ns from .sources import sources_chunks_ns, sources_ns, sources_upload_ns @@ -27,6 +28,9 @@ api.add_namespace(attachments_ns) # Conversations api.add_namespace(conversations_ns) +# Models +api.add_namespace(models_ns) + # Agents (main, sharing, webhooks) api.add_namespace(agents_ns) api.add_namespace(agents_sharing_ns) diff --git a/application/core/model_configs.py b/application/core/model_configs.py new file mode 100644 index 00000000..b802ee27 --- /dev/null +++ b/application/core/model_configs.py @@ -0,0 +1,223 @@ +""" +Model configurations for all supported LLM providers. 
+""" + +from application.core.model_settings import ( + AvailableModel, + ModelCapabilities, + ModelProvider, +) + +OPENAI_ATTACHMENTS = [ + "application/pdf", + "image/png", + "image/jpeg", + "image/jpg", + "image/webp", + "image/gif", +] + +GOOGLE_ATTACHMENTS = [ + "application/pdf", + "image/png", + "image/jpeg", + "image/jpg", + "image/webp", + "image/gif", +] + + +OPENAI_MODELS = [ + AvailableModel( + id="gpt-4o", + provider=ModelProvider.OPENAI, + display_name="GPT-4 Omni", + description="Latest and most capable model", + capabilities=ModelCapabilities( + supports_tools=True, + supports_structured_output=True, + supported_attachment_types=OPENAI_ATTACHMENTS, + context_window=128000, + ), + ), + AvailableModel( + id="gpt-4o-mini", + provider=ModelProvider.OPENAI, + display_name="GPT-4 Omni Mini", + description="Fast and efficient", + capabilities=ModelCapabilities( + supports_tools=True, + supports_structured_output=True, + supported_attachment_types=OPENAI_ATTACHMENTS, + context_window=128000, + ), + ), + AvailableModel( + id="gpt-4-turbo", + provider=ModelProvider.OPENAI, + display_name="GPT-4 Turbo", + description="Fast GPT-4 with 128k context", + capabilities=ModelCapabilities( + supports_tools=True, + supports_structured_output=True, + supported_attachment_types=OPENAI_ATTACHMENTS, + context_window=128000, + ), + ), + AvailableModel( + id="gpt-4", + provider=ModelProvider.OPENAI, + display_name="GPT-4", + description="Most capable model", + capabilities=ModelCapabilities( + supports_tools=True, + supports_structured_output=True, + supported_attachment_types=OPENAI_ATTACHMENTS, + context_window=8192, + ), + ), + AvailableModel( + id="gpt-3.5-turbo", + provider=ModelProvider.OPENAI, + display_name="GPT-3.5 Turbo", + description="Fast and cost-effective", + capabilities=ModelCapabilities( + supports_tools=True, + context_window=4096, + ), + ), +] + + +ANTHROPIC_MODELS = [ + AvailableModel( + id="claude-3-5-sonnet-20241022", + provider=ModelProvider.ANTHROPIC, + display_name="Claude 3.5 Sonnet (Latest)", + description="Latest Claude 3.5 Sonnet with enhanced capabilities", + capabilities=ModelCapabilities( + supports_tools=True, + context_window=200000, + ), + ), + AvailableModel( + id="claude-3-5-sonnet", + provider=ModelProvider.ANTHROPIC, + display_name="Claude 3.5 Sonnet", + description="Balanced performance and capability", + capabilities=ModelCapabilities( + supports_tools=True, + context_window=200000, + ), + ), + AvailableModel( + id="claude-3-opus", + provider=ModelProvider.ANTHROPIC, + display_name="Claude 3 Opus", + description="Most capable Claude model", + capabilities=ModelCapabilities( + supports_tools=True, + context_window=200000, + ), + ), + AvailableModel( + id="claude-3-haiku", + provider=ModelProvider.ANTHROPIC, + display_name="Claude 3 Haiku", + description="Fastest Claude model", + capabilities=ModelCapabilities( + supports_tools=True, + context_window=200000, + ), + ), +] + + +GOOGLE_MODELS = [ + AvailableModel( + id="gemini-flash-latest", + provider=ModelProvider.GOOGLE, + display_name="Gemini Flash (Latest)", + description="Latest experimental Gemini model", + capabilities=ModelCapabilities( + supports_tools=True, + supports_structured_output=True, + supported_attachment_types=GOOGLE_ATTACHMENTS, + context_window=int(1e6), + ), + ), + AvailableModel( + id="gemini-flash-lite-latest", + provider=ModelProvider.GOOGLE, + display_name="Gemini Flash Lite (Latest)", + description="Fast with huge context window", + capabilities=ModelCapabilities( + 
diff --git a/application/core/model_settings.py b/application/core/model_settings.py
new file mode 100644
index 00000000..87325ac3
--- /dev/null
+++ b/application/core/model_settings.py
@@ -0,0 +1,236 @@
+import logging
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Dict, List, Optional
+
+logger = logging.getLogger(__name__)
+
+
+class ModelProvider(str, Enum):
+    OPENAI = "openai"
+    AZURE_OPENAI = "azure_openai"
+    ANTHROPIC = "anthropic"
+    GROQ = "groq"
+    GOOGLE = "google"
+    HUGGINGFACE = "huggingface"
+    LLAMA_CPP = "llama.cpp"
+    DOCSGPT = "docsgpt"
+    PREMAI = "premai"
+    SAGEMAKER = "sagemaker"
+    NOVITA = "novita"
+
+
+@dataclass
+class ModelCapabilities:
+    supports_tools: bool = False
+    supports_structured_output: bool = False
+    supports_streaming: bool = True
+    supported_attachment_types: List[str] = field(default_factory=list)
+    context_window: int = 128000
+    input_cost_per_token: Optional[float] = None
+    output_cost_per_token: Optional[float] = None
+
+
+@dataclass
+class AvailableModel:
+    id: str
+    provider: ModelProvider
+    display_name: str
+    description: str = ""
+    capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
+    enabled: bool = True
+    base_url: Optional[str] = None
+
+    def to_dict(self) -> Dict:
+        result = {
+            "id": self.id,
+            "provider": self.provider.value,
+            "display_name": self.display_name,
+            "description": self.description,
+            "supported_attachment_types": self.capabilities.supported_attachment_types,
+            "supports_tools": self.capabilities.supports_tools,
+            "supports_structured_output": self.capabilities.supports_structured_output,
+            "supports_streaming": self.capabilities.supports_streaming,
+            "context_window": self.capabilities.context_window,
+            "enabled": self.enabled,
+        }
+        if self.base_url:
+            result["base_url"] = self.base_url
+        return result
+
+
+class ModelRegistry:
+    _instance = None
+    _initialized = False
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self):
+        if not ModelRegistry._initialized:
+            self.models: Dict[str, AvailableModel] = {}
+            self.default_model_id: Optional[str] = None
+            self._load_models()
+            ModelRegistry._initialized = True
+
+    @classmethod
+    def get_instance(cls) -> "ModelRegistry":
+        return cls()
+
+    def _load_models(self):
+        from application.core.settings import settings
+
+        self.models.clear()
+
+        self._add_docsgpt_models(settings)
+        if settings.OPENAI_API_KEY or (
+            settings.LLM_PROVIDER == "openai" and settings.API_KEY
+        ):
+            self._add_openai_models(settings)
+        if settings.OPENAI_API_BASE or (
+            settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
+        ):
+            self._add_azure_openai_models(settings)
+        if settings.ANTHROPIC_API_KEY or (
+            settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
+        ):
+            self._add_anthropic_models(settings)
+        if settings.GOOGLE_API_KEY or (
+            settings.LLM_PROVIDER == "google" and settings.API_KEY
+        ):
+            self._add_google_models(settings)
+        if settings.GROQ_API_KEY or (
+            settings.LLM_PROVIDER == "groq" and settings.API_KEY
+        ):
+            self._add_groq_models(settings)
+        if settings.HUGGINGFACE_API_KEY or (
+            settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
+        ):
+            self._add_huggingface_models(settings)
+
+        # Default model selection
+        if settings.LLM_NAME and settings.LLM_NAME in self.models:
+            self.default_model_id = settings.LLM_NAME
+        elif settings.LLM_PROVIDER and settings.API_KEY:
+            for model_id, model in self.models.items():
+                if model.provider.value == settings.LLM_PROVIDER:
+                    self.default_model_id = model_id
+                    break
+        else:
+            self.default_model_id = next(iter(self.models.keys()))
+
+        logger.info(
+            f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
+        )
+
+    def _add_openai_models(self, settings):
+        from application.core.model_configs import OPENAI_MODELS
+
+        if settings.OPENAI_API_KEY:
+            for model in OPENAI_MODELS:
+                self.models[model.id] = model
+            return
+        if settings.LLM_PROVIDER == "openai" and settings.LLM_NAME:
+            for model in OPENAI_MODELS:
+                if model.id == settings.LLM_NAME:
+                    self.models[model.id] = model
+            return
+        for model in OPENAI_MODELS:
+            self.models[model.id] = model
+
+    def _add_azure_openai_models(self, settings):
+        from application.core.model_configs import AZURE_OPENAI_MODELS
+
+        if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
+            for model in AZURE_OPENAI_MODELS:
+                if model.id == settings.LLM_NAME:
+                    self.models[model.id] = model
+            return
+        for model in AZURE_OPENAI_MODELS:
+            self.models[model.id] = model
+
+    def _add_anthropic_models(self, settings):
+        from application.core.model_configs import ANTHROPIC_MODELS
+
+        if settings.ANTHROPIC_API_KEY:
+            for model in ANTHROPIC_MODELS:
+                self.models[model.id] = model
+            return
+        if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
+            for model in ANTHROPIC_MODELS:
+                if model.id == settings.LLM_NAME:
+                    self.models[model.id] = model
+            return
+        for model in ANTHROPIC_MODELS:
+            self.models[model.id] = model
+
+    def _add_google_models(self, settings):
+        from application.core.model_configs import GOOGLE_MODELS
+
+        if settings.GOOGLE_API_KEY:
+            for model in GOOGLE_MODELS:
+                self.models[model.id] = model
+            return
+        if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
+            for model in GOOGLE_MODELS:
+                if model.id == settings.LLM_NAME:
+                    self.models[model.id] = model
+            return
+        for model in GOOGLE_MODELS:
+            self.models[model.id] = model
+
+    def _add_groq_models(self, settings):
+        from application.core.model_configs import GROQ_MODELS
+
+        if settings.GROQ_API_KEY:
+            for model in GROQ_MODELS:
+                self.models[model.id] = model
+            return
+        if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
+            for model in GROQ_MODELS:
+                if model.id == settings.LLM_NAME:
+                    self.models[model.id] = model
+            return
+        for model in GROQ_MODELS:
+            self.models[model.id] = model
+
+    def _add_docsgpt_models(self, settings):
+        model_id = "docsgpt-local"
+        model = AvailableModel(
+            id=model_id,
+            provider=ModelProvider.DOCSGPT,
+            display_name="DocsGPT Model",
+            description="Local model",
+            capabilities=ModelCapabilities(
+                supports_tools=False,
+                supported_attachment_types=[],
+            ),
+        )
+        self.models[model_id] = model
+
+    def _add_huggingface_models(self, settings):
+        model_id = "huggingface-local"
+        model = AvailableModel(
+            id=model_id,
+            provider=ModelProvider.HUGGINGFACE,
+            display_name="Hugging Face Model",
+            description="Local Hugging Face model",
+            capabilities=ModelCapabilities(
+                supports_tools=False,
+                supported_attachment_types=[],
+            ),
+        )
+        self.models[model_id] = model
+
+    def get_model(self, model_id: str) -> Optional[AvailableModel]:
+        return self.models.get(model_id)
+
+    def get_all_models(self) -> List[AvailableModel]:
+        return list(self.models.values())
+
+    def get_enabled_models(self) -> List[AvailableModel]:
+        return [m for m in self.models.values() if m.enabled]
+
+    def model_exists(self, model_id: str) -> bool:
+        return model_id in self.models
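Because `__new__` and the `_initialized` flag guard construction, `ModelRegistry.get_instance()` always returns the same loaded registry. A small usage sketch (the model IDs assume the OpenAI entries from `model_configs.py` are enabled in this deployment):

```python
from application.core.model_settings import ModelRegistry

registry = ModelRegistry.get_instance()  # singleton; _load_models runs once

if registry.model_exists("gpt-4o-mini"):
    model = registry.get_model("gpt-4o-mini")
    print(model.display_name, model.capabilities.context_window)  # 128000

print("default model:", registry.default_model_id)
print("enabled:", [m.id for m in registry.get_enabled_models()])
```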
diff --git a/application/core/model_utils.py b/application/core/model_utils.py
new file mode 100644
index 00000000..f24dbf47
--- /dev/null
+++ b/application/core/model_utils.py
@@ -0,0 +1,91 @@
+from typing import Any, Dict, Optional
+
+from application.core.model_settings import ModelRegistry
+
+
+def get_api_key_for_provider(provider: str) -> Optional[str]:
+    """Get the appropriate API key for a provider"""
+    from application.core.settings import settings
+
+    provider_key_map = {
+        "openai": settings.OPENAI_API_KEY,
+        "anthropic": settings.ANTHROPIC_API_KEY,
+        "google": settings.GOOGLE_API_KEY,
+        "groq": settings.GROQ_API_KEY,
+        "huggingface": settings.HUGGINGFACE_API_KEY,
+        "azure_openai": settings.API_KEY,
+        "docsgpt": None,
+        "llama.cpp": None,
+    }
+
+    provider_key = provider_key_map.get(provider)
+    if provider_key:
+        return provider_key
+    return settings.API_KEY
+
+
+def get_all_available_models() -> Dict[str, Dict[str, Any]]:
+    """Get all available models with metadata for API response"""
+    registry = ModelRegistry.get_instance()
+    return {model.id: model.to_dict() for model in registry.get_enabled_models()}
+
+
+def validate_model_id(model_id: str) -> bool:
+    """Check if a model ID exists in registry"""
+    registry = ModelRegistry.get_instance()
+    return registry.model_exists(model_id)
+
+
+def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
+    """Get capabilities for a specific model"""
+    registry = ModelRegistry.get_instance()
+    model = registry.get_model(model_id)
+    if model:
+        return {
+            "supported_attachment_types": model.capabilities.supported_attachment_types,
+            "supports_tools": model.capabilities.supports_tools,
+            "supports_structured_output": model.capabilities.supports_structured_output,
+            "context_window": model.capabilities.context_window,
+        }
+    return None
+
+
+def get_default_model_id() -> str:
+    """Get the system default model ID"""
+    registry = ModelRegistry.get_instance()
+    return registry.default_model_id
+
+
+def get_provider_from_model_id(model_id: str) -> Optional[str]:
+    """Get the provider name for a given model_id"""
+    registry = ModelRegistry.get_instance()
+    model = registry.get_model(model_id)
+    if model:
+        return model.provider.value
+    return None
+
+
+def get_token_limit(model_id: str) -> int:
+    """
+    Get the context window (token limit) for a model.
+    Returns the model's context_window, or settings.DEFAULT_LLM_TOKEN_LIMIT
+    (128000 by default) if the model is not found.
+    """
+    from application.core.settings import settings
+
+    registry = ModelRegistry.get_instance()
+    model = registry.get_model(model_id)
+    if model:
+        return model.capabilities.context_window
+    return settings.DEFAULT_LLM_TOKEN_LIMIT
+
+
+def get_base_url_for_model(model_id: str) -> Optional[str]:
+    """
+    Get the custom base_url for a specific model if configured.
+    Returns None if no custom base_url is set.
+    """
+    registry = ModelRegistry.get_instance()
+    model = registry.get_model(model_id)
+    if model:
+        return model.base_url
+    return None
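These helpers compose into the provider/key/limit resolution used by the answer routes and `LLMCreator`. For instance (values follow the registry entries above; which key is returned depends on which settings are populated):

```python
from application.core.model_utils import (
    get_api_key_for_provider,
    get_provider_from_model_id,
    get_token_limit,
)

provider = get_provider_from_model_id("claude-3-opus")  # "anthropic" if registered
api_key = get_api_key_for_provider(provider)  # ANTHROPIC_API_KEY, else settings.API_KEY
limit = get_token_limit("claude-3-opus")  # 200000 per model_configs.py

print(provider, limit, bool(api_key))
```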
diff --git a/application/core/settings.py b/application/core/settings.py
index 22116a7c..ee7ffa05 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -22,15 +22,7 @@ class Settings(BaseSettings):
     MONGO_DB_NAME: str = "docsgpt"
     LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
     DEFAULT_MAX_HISTORY: int = 150
-    LLM_TOKEN_LIMITS: dict = {
-        "gpt-4o": 128000,
-        "gpt-4o-mini": 128000,
-        "gpt-4": 8192,
-        "gpt-3.5-turbo": 4096,
-        "claude-2": int(1e5),
-        "gemini-2.5-flash": int(1e6),
-    }
-    DEFAULT_LLM_TOKEN_LIMIT: int = 128000
+    DEFAULT_LLM_TOKEN_LIMIT: int = 128000  # Fallback when model not found in registry
     RESERVED_TOKENS: dict = {
         "system_prompt": 500,
         "current_query": 500,
@@ -64,14 +56,22 @@ class Settings(BaseSettings):
     )

     # GitHub source
-    GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
+    GITHUB_ACCESS_TOKEN: Optional[str] = None  # PAT token with read repo access

     # LLM Cache
     CACHE_REDIS_URL: str = "redis://localhost:6379/2"

     API_URL: str = "http://localhost:7091"  # backend url for celery worker

-    API_KEY: Optional[str] = None # LLM api key
+    API_KEY: Optional[str] = None  # LLM api key (used by LLM_PROVIDER)
+
+    # Provider-specific API keys (for multi-model support)
+    OPENAI_API_KEY: Optional[str] = None
+    ANTHROPIC_API_KEY: Optional[str] = None
+    GOOGLE_API_KEY: Optional[str] = None
+    GROQ_API_KEY: Optional[str] = None
+    HUGGINGFACE_API_KEY: Optional[str] = None
+
     EMBEDDINGS_KEY: Optional[str] = (
         None  # api key for embeddings (if using openai, just copy API_KEY)
     )
@@ -138,11 +138,12 @@ class Settings(BaseSettings):

     # Encryption settings
     ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"

-    TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
+    TTS_PROVIDER: str = "google_tts"  # google_tts or elevenlabs
     ELEVENLABS_API_KEY: Optional[str] = None

     # Tool pre-fetch settings
     ENABLE_TOOL_PREFETCH: bool = True
+

 path = Path(__file__).parent.parent.absolute()
 settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py
index b55dd855..4d26f925 100644
--- a/application/llm/anthropic.py
+++ b/application/llm/anthropic.py
@@ -1,30 +1,41 @@
-from application.llm.base import BaseLLM
+from anthropic import AI_PROMPT, Anthropic, HUMAN_PROMPT
+
 from application.core.settings import settings
+from application.llm.base import BaseLLM


 class AnthropicLLM(BaseLLM):

-    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
-        from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
+    def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):

         super().__init__(*args, **kwargs)
-        self.api_key = (
-            api_key or settings.ANTHROPIC_API_KEY
-        )  # If not provided, use a default from settings
+        self.api_key = api_key or settings.ANTHROPIC_API_KEY or settings.API_KEY
         self.user_api_key = user_api_key
-        self.anthropic = Anthropic(api_key=self.api_key)
+
+        # Use custom base_url if provided
+        if base_url:
+            self.anthropic = Anthropic(api_key=self.api_key, base_url=base_url)
+        else:
+            self.anthropic = Anthropic(api_key=self.api_key)
+
         self.HUMAN_PROMPT = HUMAN_PROMPT
         self.AI_PROMPT = AI_PROMPT

     def _raw_gen(
-        self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs
+        self,
+        baseself,
+        model,
+        messages,
+        stream=False,
+        tools=None,
+        max_tokens=300,
+        **kwargs,
     ):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Context \n {context} \n ### Question \n {user_question}"
         if stream:
             return self.gen_stream(model, prompt, stream, max_tokens, **kwargs)
-
         completion = self.anthropic.completions.create(
             model=model,
             max_tokens_to_sample=max_tokens,
@@ -34,7 +45,14 @@ class AnthropicLLM(BaseLLM):
         return completion.completion

     def _raw_gen_stream(
-        self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs
+        self,
+        baseself,
+        model,
+        messages,
+        stream=True,
+        tools=None,
+        max_tokens=300,
+        **kwargs,
     ):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
@@ -50,5 +68,5 @@ class AnthropicLLM(BaseLLM):
             for completion in stream_response:
                 yield completion.completion
         finally:
-            if hasattr(stream_response, 'close'):
+            if hasattr(stream_response, "close"):
                 stream_response.close()
diff --git a/application/llm/base.py b/application/llm/base.py
index c16ec99e..e8ee78eb 100644
--- a/application/llm/base.py
+++ b/application/llm/base.py
@@ -13,30 +13,32 @@ class BaseLLM(ABC):
     def __init__(
         self,
         decoded_token=None,
+        model_id=None,
+        base_url=None,
     ):
         self.decoded_token = decoded_token
+        self.model_id = model_id
+        self.base_url = base_url
         self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
-        self.fallback_provider = settings.FALLBACK_LLM_PROVIDER
-        self.fallback_model_name = settings.FALLBACK_LLM_NAME
-        self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY
         self._fallback_llm = None
+        self._fallback_sequence_index = 0

     @property
     def fallback_llm(self):
-        """Lazy-loaded fallback LLM instance."""
-        if (
-            self._fallback_llm is None
-            and self.fallback_provider
-            and self.fallback_model_name
-        ):
+        """Lazy-loaded fallback LLM from FALLBACK_* settings."""
+        if self._fallback_llm is None and settings.FALLBACK_LLM_PROVIDER:
            try:
                from application.llm.llm_creator import LLMCreator

                self._fallback_llm = LLMCreator.create_llm(
-                    self.fallback_provider,
-                    self.fallback_llm_api_key,
-                    None,
-                    self.decoded_token,
+                    settings.FALLBACK_LLM_PROVIDER,
+                    api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
+                    user_api_key=None,
+                    decoded_token=self.decoded_token,
+                    model_id=settings.FALLBACK_LLM_NAME,
+                )
+                logger.info(
+                    f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
                 )
             except Exception as e:
                 logger.error(
@@ -54,7 +56,7 @@ class BaseLLM(ABC):
         self, method_name: str, decorators: list, *args, **kwargs
     ):
         """
-        Unified method execution with fallback support.
+        Execute method with fallback support.
         Args:
             method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream')
@@ -73,10 +75,10 @@ class BaseLLM(ABC):
             return decorated_method()
         except Exception as e:
             if not self.fallback_llm:
-                logger.error(f"Primary LLM failed and no fallback available: {str(e)}")
+                logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
                 raise
             logger.warning(
-                f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}"
+                f"Primary LLM failed. Falling back to {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}. Error: {str(e)}"
             )

             fallback_method = getattr(
diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py
index 3572db40..44a479ae 100644
--- a/application/llm/docsgpt_provider.py
+++ b/application/llm/docsgpt_provider.py
@@ -1,5 +1,7 @@
 import json

+from openai import OpenAI
+
 from application.core.settings import settings
 from application.llm.base import BaseLLM
@@ -7,12 +9,11 @@ from application.llm.base import BaseLLM

 class DocsGPTAPILLM(BaseLLM):

     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
-        from openai import OpenAI

         super().__init__(*args, **kwargs)
-        self.client = OpenAI(api_key="sk-docsgpt-public", base_url="https://oai.arc53.com")
+        self.api_key = "sk-docsgpt-public"
+        self.client = OpenAI(api_key=self.api_key, base_url="https://oai.arc53.com")
         self.user_api_key = user_api_key
-        self.api_key = api_key

     def _clean_messages_openai(self, messages):
         cleaned_messages = []
@@ -22,7 +23,6 @@ class DocsGPTAPILLM(BaseLLM):

             if role == "model":
                 role = "assistant"
-
             if role and content is not None:
                 if isinstance(content, str):
                     cleaned_messages.append({"role": role, "content": content})
@@ -69,7 +69,6 @@ class DocsGPTAPILLM(BaseLLM):
                     )
                 else:
                     raise ValueError(f"Unexpected content type: {type(content)}")
-
         return cleaned_messages

     def _raw_gen(
@@ -121,7 +120,6 @@ class DocsGPTAPILLM(BaseLLM):
         response = self.client.chat.completions.create(
             model="docsgpt", messages=messages, stream=stream, **kwargs
         )
-
         try:
             for line in response:
                 if (
@@ -133,8 +131,8 @@ class DocsGPTAPILLM(BaseLLM):
                 elif len(line.choices) > 0:
                     yield line.choices[0]
         finally:
-            if hasattr(response, 'close'):
+            if hasattr(response, "close"):
                 response.close()

     def _supports_tools(self):
-        return True
\ No newline at end of file
+        return True
diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py
index 47be51cd..9c58a3e1 100644
--- a/application/llm/google_ai.py
+++ b/application/llm/google_ai.py
@@ -13,8 +13,9 @@ from application.storage.storage_creator import StorageCreator

 class GoogleLLM(BaseLLM):
     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.api_key = api_key
+        self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
         self.user_api_key = user_api_key
+
         self.client = genai.Client(api_key=self.api_key)
         self.storage = StorageCreator.get_storage()
@@ -47,21 +48,19 @@ class GoogleLLM(BaseLLM):
         """
         if not attachments:
             return messages
-
         prepared_messages = messages.copy()

         # Find the user message to attach files to the last one
+
         user_message_index = None
         for i in range(len(prepared_messages) - 1, -1, -1):
             if prepared_messages[i].get("role") == "user":
                 user_message_index = i
                 break
-
         if user_message_index is None:
             user_message = {"role": "user", "content": []}
             prepared_messages.append(user_message)
             user_message_index = len(prepared_messages) - 1
-
         if isinstance(prepared_messages[user_message_index].get("content"), str):
prepared_messages[user_message_index]["content"] prepared_messages[user_message_index]["content"] = [ @@ -69,7 +68,6 @@ class GoogleLLM(BaseLLM): ] elif not isinstance(prepared_messages[user_message_index].get("content"), list): prepared_messages[user_message_index]["content"] = [] - files = [] for attachment in attachments: mime_type = attachment.get("mime_type") @@ -92,11 +90,9 @@ class GoogleLLM(BaseLLM): "text": f"[File could not be processed: {attachment.get('path', 'unknown')}]", } ) - if files: logging.info(f"GoogleLLM: Adding {len(files)} files to message") prepared_messages[user_message_index]["content"].append({"files": files}) - return prepared_messages def _upload_file_to_google(self, attachment): @@ -111,14 +107,11 @@ class GoogleLLM(BaseLLM): """ if "google_file_uri" in attachment: return attachment["google_file_uri"] - file_path = attachment.get("path") if not file_path: raise ValueError("No file path provided in attachment") - if not self.storage.file_exists(file_path): raise FileNotFoundError(f"File not found: {file_path}") - try: file_uri = self.storage.process_file( file_path, @@ -136,7 +129,6 @@ class GoogleLLM(BaseLLM): attachments_collection.update_one( {"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}} ) - return file_uri except Exception as e: logging.error(f"Error uploading file to Google AI: {e}", exc_info=True) @@ -153,7 +145,6 @@ class GoogleLLM(BaseLLM): role = "model" elif role == "tool": role = "model" - parts = [] if role and content is not None: if isinstance(content, str): @@ -164,6 +155,7 @@ class GoogleLLM(BaseLLM): parts.append(types.Part.from_text(text=item["text"])) elif "function_call" in item: # Remove null values from args to avoid API errors + cleaned_args = self._remove_null_values( item["function_call"]["args"] ) @@ -194,10 +186,8 @@ class GoogleLLM(BaseLLM): ) else: raise ValueError(f"Unexpected content type: {type(content)}") - if parts: cleaned_messages.append(types.Content(role=role, parts=parts)) - return cleaned_messages def _clean_schema(self, schema_obj): @@ -233,8 +223,8 @@ class GoogleLLM(BaseLLM): cleaned[key] = [self._clean_schema(item) for item in value] else: cleaned[key] = value - # Validate that required properties actually exist in properties + if "required" in cleaned and "properties" in cleaned: valid_required = [] properties_keys = set(cleaned["properties"].keys()) @@ -247,7 +237,6 @@ class GoogleLLM(BaseLLM): cleaned.pop("required", None) elif "required" in cleaned and "properties" not in cleaned: cleaned.pop("required", None) - return cleaned def _clean_tools_format(self, tools_list): @@ -263,7 +252,6 @@ class GoogleLLM(BaseLLM): cleaned_properties = {} for k, v in properties.items(): cleaned_properties[k] = self._clean_schema(v) - genai_function = dict( name=function["name"], description=function["description"], @@ -282,10 +270,8 @@ class GoogleLLM(BaseLLM): name=function["name"], description=function["description"], ) - genai_tool = types.Tool(function_declarations=[genai_function]) genai_tools.append(genai_tool) - return genai_tools def _raw_gen( @@ -307,16 +293,14 @@ class GoogleLLM(BaseLLM): if messages[0].role == "system": config.system_instruction = messages[0].parts[0].text messages = messages[1:] - if tools: cleaned_tools = self._clean_tools_format(tools) config.tools = cleaned_tools - # Add response schema for structured output if provided + if response_schema: config.response_schema = response_schema config.response_mime_type = "application/json" - response = client.models.generate_content( 
             model=model,
             contents=messages,
@@ -347,17 +331,16 @@ class GoogleLLM(BaseLLM):
         if messages[0].role == "system":
             config.system_instruction = messages[0].parts[0].text
             messages = messages[1:]
-
         if tools:
             cleaned_tools = self._clean_tools_format(tools)
             config.tools = cleaned_tools
-
         # Add response schema for structured output if provided
+
         if response_schema:
             config.response_schema = response_schema
             config.response_mime_type = "application/json"
-
         # Check if we have both tools and file attachments
+
         has_attachments = False
         for message in messages:
             for part in message.parts:
@@ -366,7 +349,6 @@ class GoogleLLM(BaseLLM):
                     break
             if has_attachments:
                 break
-
         logging.info(
             f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
         )
@@ -405,7 +387,6 @@ class GoogleLLM(BaseLLM):
         """Convert JSON schema to Google AI structured output format."""
         if not json_schema:
             return None
-
         type_map = {
             "object": "OBJECT",
             "array": "ARRAY",
@@ -418,12 +399,10 @@ class GoogleLLM(BaseLLM):
         def convert(schema):
             if not isinstance(schema, dict):
                 return schema
-
             result = {}
             schema_type = schema.get("type")
             if schema_type:
                 result["type"] = type_map.get(schema_type.lower(), schema_type.upper())
-
             for key in [
                 "description",
                 "nullable",
@@ -435,7 +414,6 @@ class GoogleLLM(BaseLLM):
             ]:
                 if key in schema:
                     result[key] = schema[key]
-
             if "format" in schema:
                 format_value = schema["format"]
                 if schema_type == "string":
@@ -445,21 +423,17 @@ class GoogleLLM(BaseLLM):
                     result["format"] = format_value
                 else:
                     result["format"] = format_value
-
             if "properties" in schema:
                 result["properties"] = {
                     k: convert(v) for k, v in schema["properties"].items()
                 }
                 if "propertyOrdering" not in result and result.get("type") == "OBJECT":
                     result["propertyOrdering"] = list(result["properties"].keys())
-
             if "items" in schema:
                 result["items"] = convert(schema["items"])
-
             for field in ["anyOf", "oneOf", "allOf"]:
                 if field in schema:
                     result[field] = [convert(s) for s in schema[field]]
-
             return result

         try:
diff --git a/application/llm/groq.py b/application/llm/groq.py
index 282d7f47..c2ae40ee 100644
--- a/application/llm/groq.py
+++ b/application/llm/groq.py
@@ -1,13 +1,18 @@
-from application.llm.base import BaseLLM
 from openai import OpenAI

+from application.core.settings import settings
+from application.llm.base import BaseLLM
+

 class GroqLLM(BaseLLM):

     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
+
         super().__init__(*args, **kwargs)
-        self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
-        self.api_key = api_key
+        self.api_key = api_key or settings.GROQ_API_KEY or settings.API_KEY
         self.user_api_key = user_api_key
+        self.client = OpenAI(
+            api_key=self.api_key, base_url="https://api.groq.com/openai/v1"
+        )

     def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
         if tools:
diff --git a/application/llm/handlers/base.py b/application/llm/handlers/base.py
index 96ed4c00..920caf65 100644
--- a/application/llm/handlers/base.py
+++ b/application/llm/handlers/base.py
@@ -282,7 +282,7 @@ class LLMHandler(ABC):
                 messages = e.value
                 break
         response = agent.llm.gen(
-            model=agent.gpt_model, messages=messages, tools=agent.tools
+            model=agent.model_id, messages=messages, tools=agent.tools
         )
         parsed = self.parse_response(response)
         self.llm_calls.append(build_stack_data(agent.llm))
@@ -337,7 +337,7 @@ class LLMHandler(ABC):
         tool_calls = {}

         response = agent.llm.gen_stream(
-            model=agent.gpt_model, messages=messages, tools=agent.tools
+            model=agent.model_id, messages=messages, tools=agent.tools
         )
         self.llm_calls.append(build_stack_data(agent.llm))
diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py
index 3ed23854..21d653b9 100644
--- a/application/llm/llm_creator.py
+++ b/application/llm/llm_creator.py
@@ -1,13 +1,17 @@
-from application.llm.groq import GroqLLM
-from application.llm.openai import OpenAILLM, AzureOpenAILLM
-from application.llm.sagemaker import SagemakerAPILLM
-from application.llm.huggingface import HuggingFaceLLM
-from application.llm.llama_cpp import LlamaCpp
+import logging
+
 from application.llm.anthropic import AnthropicLLM
 from application.llm.docsgpt_provider import DocsGPTAPILLM
-from application.llm.premai import PremAILLM
 from application.llm.google_ai import GoogleLLM
+from application.llm.groq import GroqLLM
+from application.llm.huggingface import HuggingFaceLLM
+from application.llm.llama_cpp import LlamaCpp
 from application.llm.novita import NovitaLLM
+from application.llm.openai import AzureOpenAILLM, OpenAILLM
+from application.llm.premai import PremAILLM
+from application.llm.sagemaker import SagemakerAPILLM
+
+logger = logging.getLogger(__name__)


 class LLMCreator:
@@ -26,10 +30,26 @@ class LLMCreator:
     }

     @classmethod
-    def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
+    def create_llm(
+        cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
+    ):
+        from application.core.model_utils import get_base_url_for_model
+
         llm_class = cls.llms.get(type.lower())
         if not llm_class:
             raise ValueError(f"No LLM class found for type {type}")
+
+        # Extract base_url from model configuration if model_id is provided
+        base_url = None
+        if model_id:
+            base_url = get_base_url_for_model(model_id)
+
         return llm_class(
-            api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
+            api_key,
+            user_api_key,
+            decoded_token=decoded_token,
+            model_id=model_id,
+            base_url=base_url,
+            *args,
+            **kwargs,
         )
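End to end, `create_llm` now threads `model_id` through to the LLM class and resolves any per-model `base_url`. A hedged sketch (the API key and token payload are placeholders, not values from this PR):

```python
from application.llm.llm_creator import LLMCreator

llm = LLMCreator.create_llm(
    "openai",
    api_key="sk-placeholder",          # normally resolved via get_api_key_for_provider()
    user_api_key=None,
    decoded_token={"sub": "user-1"},   # assumed token shape
    model_id="gpt-4o-mini",            # base_url looked up via get_base_url_for_model()
)
answer = llm.gen(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Hello"}],
)
```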
diff --git a/application/llm/openai.py b/application/llm/openai.py
index de24e2c5..beab465b 100644
--- a/application/llm/openai.py
+++ b/application/llm/openai.py
@@ -2,6 +2,8 @@ import base64
 import json
 import logging

+from openai import OpenAI
+
 from application.core.settings import settings
 from application.llm.base import BaseLLM
 from application.storage.storage_creator import StorageCreator
@@ -9,20 +11,25 @@ from application.storage.storage_creator import StorageCreator

 class OpenAILLM(BaseLLM):

-    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
-        from openai import OpenAI
+    def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):

         super().__init__(*args, **kwargs)
-        if (
+        self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
+        self.user_api_key = user_api_key
+
+        # Priority: 1) Parameter base_url, 2) Settings OPENAI_BASE_URL, 3) Default
+        effective_base_url = None
+        if base_url and isinstance(base_url, str) and base_url.strip():
+            effective_base_url = base_url
+        elif (
             isinstance(settings.OPENAI_BASE_URL, str)
             and settings.OPENAI_BASE_URL.strip()
         ):
-            self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
+            effective_base_url = settings.OPENAI_BASE_URL
         else:
-            DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
-            self.client = OpenAI(api_key=api_key, base_url=DEFAULT_OPENAI_API_BASE)
-        self.api_key = api_key
-        self.user_api_key = user_api_key
+            effective_base_url = "https://api.openai.com/v1"
+
+        self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
         self.storage = StorageCreator.get_storage()

     def _clean_messages_openai(self, messages):
@@ -33,7 +40,6 @@ class OpenAILLM(BaseLLM):

             if role == "model":
                 role = "assistant"
-
             if role and content is not None:
                 if isinstance(content, str):
                     cleaned_messages.append({"role": role, "content": content})
@@ -107,7 +113,6 @@ class OpenAILLM(BaseLLM):
                     )
                 else:
                     raise ValueError(f"Unexpected content type: {type(content)}")
-
         return cleaned_messages

     def _raw_gen(
@@ -132,10 +137,8 @@ class OpenAILLM(BaseLLM):

         if tools:
             request_params["tools"] = tools
-
         if response_format:
             request_params["response_format"] = response_format
-
         response = self.client.chat.completions.create(**request_params)

         if tools:
@@ -165,10 +168,8 @@ class OpenAILLM(BaseLLM):

         if tools:
             request_params["tools"] = tools
-
         if response_format:
             request_params["response_format"] = response_format
-
         response = self.client.chat.completions.create(**request_params)

         try:
@@ -194,7 +195,6 @@ class OpenAILLM(BaseLLM):
     def prepare_structured_output_format(self, json_schema):
         if not json_schema:
             return None
-
         try:

             def add_additional_properties_false(schema_obj):
@@ -204,11 +204,11 @@ class OpenAILLM(BaseLLM):
                     if schema_copy.get("type") == "object":
                         schema_copy["additionalProperties"] = False
                         # Ensure 'required' includes all properties for OpenAI strict mode
+
                         if "properties" in schema_copy:
                             schema_copy["required"] = list(
                                 schema_copy["properties"].keys()
                             )
-
                     for key, value in schema_copy.items():
                         if key == "properties" and isinstance(value, dict):
                             schema_copy[key] = {
@@ -224,7 +224,6 @@ class OpenAILLM(BaseLLM):
                                 add_additional_properties_false(sub_schema)
                                 for sub_schema in value
                             ]
-
                     return schema_copy
                 return schema_obj
@@ -243,7 +242,6 @@ class OpenAILLM(BaseLLM):
             }

             return result
-
         except Exception as e:
             logging.error(f"Error preparing structured output format: {e}")
             return None
@@ -277,21 +275,19 @@ class OpenAILLM(BaseLLM):
         """
         if not attachments:
             return messages
-
         prepared_messages = messages.copy()

         # Find the user message to attach file_id to the last one
+
         user_message_index = None
         for i in range(len(prepared_messages) - 1, -1, -1):
             if prepared_messages[i].get("role") == "user":
                 user_message_index = i
                 break
-
         if user_message_index is None:
             user_message = {"role": "user", "content": []}
             prepared_messages.append(user_message)
             user_message_index = len(prepared_messages) - 1
-
         if isinstance(prepared_messages[user_message_index].get("content"), str):
             text_content = prepared_messages[user_message_index]["content"]
             prepared_messages[user_message_index]["content"] = [
@@ -299,7 +295,6 @@ class OpenAILLM(BaseLLM):
             ]
         elif not isinstance(prepared_messages[user_message_index].get("content"), list):
             prepared_messages[user_message_index]["content"] = []
-
         for attachment in attachments:
             mime_type = attachment.get("mime_type")

@@ -326,6 +321,7 @@ class OpenAILLM(BaseLLM):
                         }
                     )
                 # Handle PDFs using the file API
+
                 elif mime_type == "application/pdf":
                     try:
                         file_id = self._upload_file_to_openai(attachment)
@@ -341,7 +337,6 @@ class OpenAILLM(BaseLLM):
                             "text": f"File content:\n\n{attachment['content']}",
                         }
                     )
-
         return prepared_messages

     def _get_base64_image(self, attachment):
@@ -357,7 +352,6 @@ class OpenAILLM(BaseLLM):
         file_path = attachment.get("path")
         if not file_path:
             raise ValueError("No file path provided in attachment")
-
         try:
             with self.storage.get_file(file_path) as image_file:
                 return base64.b64encode(image_file.read()).decode("utf-8")
@@ -381,12 +375,10 @@ class OpenAILLM(BaseLLM):
in attachment: return attachment["openai_file_id"] - file_path = attachment.get("path") if not self.storage.file_exists(file_path): raise FileNotFoundError(f"File not found: {file_path}") - try: file_id = self.storage.process_file( file_path, @@ -404,7 +396,6 @@ class OpenAILLM(BaseLLM): attachments_collection.update_one( {"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}} ) - return file_id except Exception as e: logging.error(f"Error uploading file to OpenAI: {e}", exc_info=True) diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index f0428a26..f0468a18 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -16,7 +16,7 @@ class ClassicRAG(BaseRetriever): prompt="", chunks=2, doc_token_limit=50000, - gpt_model="docsgpt", + model_id="docsgpt-local", user_api_key=None, llm_name=settings.LLM_PROVIDER, api_key=settings.API_KEY, @@ -40,7 +40,7 @@ class ClassicRAG(BaseRetriever): f"ClassicRAG initialized with chunks={self.chunks}, user_api_key={user_identifier}, " f"sources={'active_docs' in source and source['active_docs'] is not None}" ) - self.gpt_model = gpt_model + self.model_id = model_id self.doc_token_limit = doc_token_limit self.user_api_key = user_api_key self.llm_name = llm_name @@ -100,7 +100,7 @@ class ClassicRAG(BaseRetriever): ] try: - rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages) + rephrased_query = self.llm.gen(model=self.model_id, messages=messages) print(f"Rephrased query: {rephrased_query}") return rephrased_query if rephrased_query else self.original_question except Exception as e: diff --git a/application/utils.py b/application/utils.py index 06eaf495..89b884f0 100644 --- a/application/utils.py +++ b/application/utils.py @@ -7,6 +7,8 @@ import tiktoken from flask import jsonify, make_response from werkzeug.utils import secure_filename +from application.core.model_utils import get_token_limit + from application.core.settings import settings @@ -75,11 +77,9 @@ def count_tokens_docs(docs): def calculate_doc_token_budget( - gpt_model: str = "gpt-4o", history_token_limit: int = 2000 + model_id: str = "gpt-4o", history_token_limit: int = 2000 ) -> int: - total_context = settings.LLM_TOKEN_LIMITS.get( - gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT - ) + total_context = get_token_limit(model_id) reserved = sum(settings.RESERVED_TOKENS.values()) doc_budget = total_context - history_token_limit - reserved return max(doc_budget, 1000) @@ -144,16 +144,13 @@ def get_hash(data): return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest() -def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"): +def limit_chat_history(history, max_token_limit=None, model_id="docsgpt-local"): """Limit chat history to fit within token limit.""" - from application.core.settings import settings - + model_token_limit = get_token_limit(model_id) max_token_limit = ( max_token_limit - if max_token_limit - and max_token_limit - < settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT) - else settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT) + if max_token_limit and max_token_limit < model_token_limit + else model_token_limit ) if not history: @@ -205,37 +202,44 @@ def clean_text_for_tts(text: str) -> str: clean text for Text-to-Speech processing. 
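Note: the application.core.model_utils helpers this patch starts calling (get_base_url_for_model above, get_token_limit in application/utils.py below) are not included in the diff. The following is a minimal sketch of the shape they could take, assuming a simple in-process registry keyed by model_id; the registry contents, field names, and fallback limit are assumptions, not code from this repository.

# Hypothetical sketch of application/core/model_utils.py -- NOT part of this
# patch. Registry shape, example entries, and the fallback limit are assumed.
from typing import Optional

_MODEL_REGISTRY = {
    # model_id -> configuration (illustrative values only)
    "docsgpt-local": {"token_limit": 8192, "base_url": None},
    "gpt-4o": {"token_limit": 128000, "base_url": None},
}

_DEFAULT_TOKEN_LIMIT = 8192  # assumed fallback for unknown model_ids


def get_base_url_for_model(model_id: str) -> Optional[str]:
    # None lets the provider client fall back to its own default endpoint,
    # matching how LLMCreator.create_llm passes base_url through above.
    return (_MODEL_REGISTRY.get(model_id) or {}).get("base_url")


def get_token_limit(model_id: str) -> int:
    # Context-window size consumed by calculate_doc_token_budget and
    # limit_chat_history in application/utils.py below.
    return (_MODEL_REGISTRY.get(model_id) or {}).get("token_limit", _DEFAULT_TOKEN_LIMIT)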
""" # Handle code blocks and links - text = re.sub(r'```mermaid[\s\S]*?```', ' flowchart, ', text) ## ```mermaid...``` - text = re.sub(r'```[\s\S]*?```', ' code block, ', text) ## ```code``` - text = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', text) ## [text](url) - text = re.sub(r'!\[([^\]]*)\]\([^\)]+\)', '', text) ## ![alt](url) + + text = re.sub(r"```mermaid[\s\S]*?```", " flowchart, ", text) ## ```mermaid...``` + text = re.sub(r"```[\s\S]*?```", " code block, ", text) ## ```code``` + text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text) ## [text](url) + text = re.sub(r"!\[([^\]]*)\]\([^\)]+\)", "", text) ## ![alt](url) # Remove markdown formatting - text = re.sub(r'`([^`]+)`', r'\1', text) ## `code` - text = re.sub(r'\{([^}]*)\}', r' \1 ', text) ## {text} - text = re.sub(r'[{}]', ' ', text) ## unmatched {} - text = re.sub(r'\[([^\]]+)\]', r' \1 ', text) ## [text] - text = re.sub(r'[\[\]]', ' ', text) ## unmatched [] - text = re.sub(r'(\*\*|__)(.*?)\1', r'\2', text) ## **bold** __bold__ - text = re.sub(r'(\*|_)(.*?)\1', r'\2', text) ## *italic* _italic_ - text = re.sub(r'^#{1,6}\s+', '', text, flags=re.MULTILINE) ## # headers - text = re.sub(r'^>\s+', '', text, flags=re.MULTILINE) ## > blockquotes - text = re.sub(r'^[\s]*[-\*\+]\s+', '', text, flags=re.MULTILINE) ## - * + lists - text = re.sub(r'^[\s]*\d+\.\s+', '', text, flags=re.MULTILINE) ## 1. numbered lists - text = re.sub(r'^[\*\-_]{3,}\s*$', '', text, flags=re.MULTILINE) ## --- *** ___ rules - text = re.sub(r'<[^>]*>', '', text) ## tags - #Remove non-ASCII (emojis, special Unicode) - text = re.sub(r'[^\x20-\x7E\n\r\t]', '', text) + text = re.sub(r"`([^`]+)`", r"\1", text) ## `code` + text = re.sub(r"\{([^}]*)\}", r" \1 ", text) ## {text} + text = re.sub(r"[{}]", " ", text) ## unmatched {} + text = re.sub(r"\[([^\]]+)\]", r" \1 ", text) ## [text] + text = re.sub(r"[\[\]]", " ", text) ## unmatched [] + text = re.sub(r"(\*\*|__)(.*?)\1", r"\2", text) ## **bold** __bold__ + text = re.sub(r"(\*|_)(.*?)\1", r"\2", text) ## *italic* _italic_ + text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE) ## # headers + text = re.sub(r"^>\s+", "", text, flags=re.MULTILINE) ## > blockquotes + text = re.sub(r"^[\s]*[-\*\+]\s+", "", text, flags=re.MULTILINE) ## - * + lists + text = re.sub(r"^[\s]*\d+\.\s+", "", text, flags=re.MULTILINE) ## 1. 
numbered lists + text = re.sub( + r"^[\*\-_]{3,}\s*$", "", text, flags=re.MULTILINE + ) ## --- *** ___ rules + text = re.sub(r"<[^>]*>", "", text) ## tags - #Replace special sequences - text = re.sub(r'-->', ', ', text) ## --> - text = re.sub(r'<--', ', ', text) ## <-- - text = re.sub(r'=>', ', ', text) ## => - text = re.sub(r'::', ' ', text) ## :: + # Remove non-ASCII (emojis, special Unicode) - #Normalize whitespace - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"[^\x20-\x7E\n\r\t]", "", text) + + # Replace special sequences + + text = re.sub(r"-->", ", ", text) ## --> + text = re.sub(r"<--", ", ", text) ## <-- + text = re.sub(r"=>", ", ", text) ## => + text = re.sub(r"::", " ", text) ## :: + + # Normalize whitespace + + text = re.sub(r"\s+", " ", text) text = text.strip() return text diff --git a/application/worker.py b/application/worker.py index f17e1537..3f957527 100755 --- a/application/worker.py +++ b/application/worker.py @@ -165,7 +165,7 @@ def run_agent_logic(agent_config, input_data): agent_type, endpoint="webhook", llm_name=settings.LLM_PROVIDER, - gpt_model=settings.LLM_NAME, + model_id=settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key, prompt=prompt, @@ -180,7 +180,7 @@ def run_agent_logic(agent_config, input_data): prompt=prompt, chunks=chunks, token_limit=settings.DEFAULT_MAX_HISTORY, - gpt_model=settings.LLM_NAME, + model_id=settings.LLM_NAME, user_api_key=user_api_key, decoded_token=decoded_token, ) diff --git a/frontend/src/Hero.tsx b/frontend/src/Hero.tsx index 9b17c10f..8b08430b 100644 --- a/frontend/src/Hero.tsx +++ b/frontend/src/Hero.tsx @@ -1,6 +1,8 @@ -import DocsGPT3 from './assets/cute_docsgpt3.svg'; import { useTranslation } from 'react-i18next'; +import DocsGPT3 from './assets/cute_docsgpt3.svg'; +import DropdownModel from './components/DropdownModel'; + export default function Hero({ handleQuestion, }: { @@ -26,6 +28,10 @@ export default function Hero({ DocsGPT docsgpt + {/* Model Selector */} +
+ +
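As a sanity check on the reformatted clean_text_for_tts above, here is an illustrative call; the expected output is inferred from the regex pipeline itself, not copied from the project's test suite.

from application.utils import clean_text_for_tts

raw = "## Setup\nRun `pip install docsgpt` and see [the docs](https://example.com) --> done"
print(clean_text_for_tts(raw))
# Header markers, inline-code backticks, and the link URL are stripped,
# "-->" becomes ", ", and whitespace collapses to single spaces:
# "Setup Run pip install docsgpt and see the docs , done"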
diff --git a/frontend/src/Hero.tsx b/frontend/src/Hero.tsx
index 9b17c10f..8b08430b 100644
--- a/frontend/src/Hero.tsx
+++ b/frontend/src/Hero.tsx
@@ -1,6 +1,8 @@
-import DocsGPT3 from './assets/cute_docsgpt3.svg';
 import { useTranslation } from 'react-i18next';

+import DocsGPT3 from './assets/cute_docsgpt3.svg';
+import DropdownModel from './components/DropdownModel';
+
 export default function Hero({
   handleQuestion,
 }: {
@@ -26,6 +28,10 @@ export default function Hero({
           DocsGPT
           docsgpt
+        {/* Model Selector */}
+        <div>
+          <DropdownModel />
+        </div>
+
         {/* Demo Buttons Section */}
@@ -38,7 +44,7 @@ export default function Hero({
+              <MultiSelectPopup
+                onClose={() => setIsModelsPopupOpen(false)}
+                anchorRef={modelAnchorButtonRef}
+                options={availableModels.map((model) => ({
+                  id: model.id,
+                  label: model.display_name,
+                }))}
+                selectedIds={selectedModelIds}
+                onSelectionChange={(newSelectedIds: Set<string>) =>
+                  setSelectedModelIds(
+                    new Set(Array.from(newSelectedIds).map(String)),
+                  )
+                }
+                title={t('agents.form.modelsPopup.title')}
+                searchPlaceholder={t(
+                  'agents.form.modelsPopup.searchPlaceholder',
+                )}
+                noOptionsMessage={t('agents.form.modelsPopup.noOptionsMessage')}
+              />
+              {selectedModelIds.size > 0 && (
+                <div>
+                  <Dropdown
+                    options={availableModels
+                      .filter((m) => selectedModelIds.has(m.id))
+                      .map((m) => ({
+                        label: m.display_name,
+                        value: m.id,
+                      }))}
+                    selectedValue={
+                      availableModels.find(
+                        (m) => m.id === agent.default_model_id,
+                      )?.display_name || null
+                    }
+                    onSelect={(option: { label: string; value: string }) =>
+                      setAgent({ ...agent, default_model_id: option.value })
+                    }
+                    size="w-full"
+                    rounded="3xl"
+                    border="border"
+                    buttonClassName="bg-white dark:bg-[#222327] border-silver dark:border-[#7E7E7E]"
+                    optionsClassName="bg-white dark:bg-[#383838] border-silver dark:border-[#7E7E7E]"
+                    placeholder={t(
+                      'agents.form.placeholders.selectDefaultModel',
+                    )}
+                    placeholderClassName="text-gray-400 dark:text-silver"
+                    contentSize="text-sm"
+                  />
+                </div>
+              )}
+
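Taken together, these changes thread a single model_id from the request layer down to the provider client, deriving both the base_url and the token limits from it. The sketch below is hedged: the provider key "openai", the API key, and the model_id are placeholders, and the actual contents of LLMCreator.llms are not shown in this diff.

from application.llm.llm_creator import LLMCreator
from application.utils import calculate_doc_token_budget

# model_id now selects the provider-side model name and, via
# get_base_url_for_model, an optional per-model endpoint override.
llm = LLMCreator.create_llm(
    "openai",          # assumed key in LLMCreator.llms
    "sk-placeholder",  # api_key (placeholder)
    None,              # user_api_key
    None,              # decoded_token
    model_id="gpt-4o",
)

# The same model_id drives the document token budget via get_token_limit.
print(calculate_doc_token_budget(model_id="gpt-4o", history_token_limit=2000))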