diff --git a/application/agents/base.py b/application/agents/base.py index a4bbd001..f48418b3 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -256,12 +256,21 @@ class BaseAgent(ABC): return retrieved_data def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None): - resp = self.llm.gen_stream( - model=self.gpt_model, messages=messages, tools=self.tools - ) + gen_kwargs = {"model": self.gpt_model, "messages": messages} + + if ( + hasattr(self.llm, "_supports_tools") + and self.llm._supports_tools + and self.tools + ): + gen_kwargs["tools"] = self.tools + + resp = self.llm.gen_stream(**gen_kwargs) + if log_context: data = build_stack_data(self.llm, exclude_attributes=["client"]) log_context.stacks.append({"component": "llm", "data": data}) + return resp def _llm_handler( diff --git a/application/llm/base.py b/application/llm/base.py index 0607159d..b145816d 100644 --- a/application/llm/base.py +++ b/application/llm/base.py @@ -1,53 +1,118 @@ +import logging from abc import ABC, abstractmethod from application.cache import gen_cache, stream_cache + +from application.core.settings import settings from application.usage import gen_token_usage, stream_token_usage +logger = logging.getLogger(__name__) + class BaseLLM(ABC): - def __init__(self, decoded_token=None): + def __init__( + self, + decoded_token=None, + ): self.decoded_token = decoded_token self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0} + self.fallback_provider = settings.FALLBACK_LLM_PROVIDER + self.fallback_model_name = settings.FALLBACK_LLM_NAME + self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY + self._fallback_llm = None - def _apply_decorator(self, method, decorators, *args, **kwargs): - for decorator in decorators: - method = decorator(method) - return method(self, *args, **kwargs) + @property + def fallback_llm(self): + """Lazy-loaded fallback LLM instance.""" + if ( + self._fallback_llm is None + and self.fallback_provider + and self.fallback_model_name + ): + try: + from llm.llm_creator import LLMCreator + + self._fallback_llm = LLMCreator( + self.fallback_provider, + self.fallback_llm_api_key, + None, + self.decoded_token, + ) + except Exception as e: + logger.error( + f"Failed to initialize fallback LLM: {str(e)}", exc_info=True + ) + return self._fallback_llm + + def _execute_with_fallback( + self, method_name: str, decorators: list, *args, **kwargs + ): + """ + Unified method execution with fallback support. + + Args: + method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream') + decorators: List of decorators to apply + *args: Positional arguments + **kwargs: Keyword arguments + """ + + def decorated_method(): + method = getattr(self, method_name) + for decorator in decorators: + method = decorator(method) + return method(self, *args, **kwargs) + + try: + return decorated_method() + except Exception as e: + if not self.fallback_llm: + logger.error(f"Primary LLM failed and no fallback available: {str(e)}") + raise + logger.warning( + f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}" + ) + # Retry with fallback (without decorators for accurate token tracking) + + fallback_method = getattr( + self.fallback_llm, method_name.replace("_raw_", "") + ) + return fallback_method(*args, **kwargs) + + def gen(self, model, messages, stream=False, tools=None, *args, **kwargs): + decorators = [gen_token_usage, gen_cache] + return self._execute_with_fallback( + "_raw_gen", + decorators, + model=model, + messages=messages, + stream=stream, + tools=tools, + *args, + **kwargs, + ) + + def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs): + decorators = [stream_cache, stream_token_usage] + return self._execute_with_fallback( + "_raw_gen_stream", + decorators, + model=model, + messages=messages, + stream=stream, + tools=tools, + *args, + **kwargs, + ) @abstractmethod def _raw_gen(self, model, messages, stream, tools, *args, **kwargs): pass - def gen(self, model, messages, stream=False, tools=None, *args, **kwargs): - decorators = [gen_token_usage, gen_cache] - return self._apply_decorator( - self._raw_gen, - decorators=decorators, - model=model, - messages=messages, - stream=stream, - tools=tools, - *args, - **kwargs - ) - @abstractmethod def _raw_gen_stream(self, model, messages, stream, *args, **kwargs): pass - def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs): - decorators = [stream_cache, stream_token_usage] - return self._apply_decorator( - self._raw_gen_stream, - decorators=decorators, - model=model, - messages=messages, - stream=stream, - tools=tools, - *args, - **kwargs - ) - def supports_tools(self): return hasattr(self, "_supports_tools") and callable( getattr(self, "_supports_tools") @@ -55,11 +120,11 @@ class BaseLLM(ABC): def _supports_tools(self): raise NotImplementedError("Subclass must implement _supports_tools method") - + def get_supported_attachment_types(self): """ Return a list of MIME types supported by this LLM for file uploads. - + Returns: list: List of supported MIME types """