diff --git a/application/cache.py b/application/cache.py
index 76b594c9..72bab2b9 100644
--- a/application/cache.py
+++ b/application/cache.py
@@ -68,7 +68,7 @@ def gen_cache(func):
 
 
 def stream_cache(func):
-    def wrapper(self, model, messages, stream, *args, **kwargs):
+    def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
         cache_key = gen_cache_key(messages)
         logger.info(f"Stream cache key: {cache_key}")
 
diff --git a/application/llm/base.py b/application/llm/base.py
index b9b0e524..e687e567 100644
--- a/application/llm/base.py
+++ b/application/llm/base.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
+
+from application.cache import gen_cache, stream_cache
 from application.usage import gen_token_usage, stream_token_usage
-from application.cache import stream_cache, gen_cache
 
 
 class BaseLLM(ABC):
@@ -18,18 +19,38 @@ class BaseLLM(ABC):
 
     def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
         decorators = [gen_token_usage, gen_cache]
-        return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs)
+        return self._apply_decorator(
+            self._raw_gen,
+            decorators=decorators,
+            model=model,
+            messages=messages,
+            stream=stream,
+            tools=tools,
+            *args,
+            **kwargs
+        )
 
     @abstractmethod
     def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
         pass
 
-    def gen_stream(self, model, messages, stream=True, *args, **kwargs):
+    def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
         decorators = [stream_cache, stream_token_usage]
-        return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
-    
+        return self._apply_decorator(
+            self._raw_gen_stream,
+            decorators=decorators,
+            model=model,
+            messages=messages,
+            stream=stream,
+            tools=tools,
+            *args,
+            **kwargs
+        )
+
     def supports_tools(self):
-        return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools'))
+        return hasattr(self, "_supports_tools") and callable(
+            getattr(self, "_supports_tools")
+        )
 
     def _supports_tools(self):
-        raise NotImplementedError("Subclass must implement _supports_tools method")
\ No newline at end of file
+        raise NotImplementedError("Subclass must implement _supports_tools method")
diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py
index 9ebc0a5e..7e67a4cd 100644
--- a/application/llm/google_ai.py
+++ b/application/llm/google_ai.py
@@ -8,7 +8,8 @@ from application.llm.base import BaseLLM
 class GoogleLLM(BaseLLM):
     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.client = genai.Client(api_key=api_key)
+        self.api_key = api_key
+        self.user_api_key = user_api_key
 
     def _clean_messages_google(self, messages):
         cleaned_messages = []
@@ -16,11 +17,32 @@ class GoogleLLM(BaseLLM):
             role = message.get("role")
             content = message.get("content")
 
+            parts = []
             if role and content is not None:
                 if isinstance(content, str):
                     parts = [types.Part.from_text(content)]
                 elif isinstance(content, list):
-                    parts = content
+                    for item in content:
+                        if "text" in item:
+                            parts.append(types.Part.from_text(item["text"]))
+                        elif "function_call" in item:
+                            parts.append(
+                                types.Part.from_function_call(
+                                    name=item["function_call"]["name"],
+                                    args=item["function_call"]["args"],
+                                )
+                            )
+                        elif "function_response" in item:
+                            parts.append(
+                                types.Part.from_function_response(
+                                    name=item["function_response"]["name"],
+                                    response=item["function_response"]["response"],
+                                )
+                            )
+                        else:
+                            raise ValueError(
+                                f"Unexpected content dictionary format:{item}"
+                            )
                 else:
                     raise ValueError(f"Unexpected content type: {type(content)}")
 
@@ -67,7 +89,7 @@ class GoogleLLM(BaseLLM):
         formatting="openai",
         **kwargs,
     ):
-        client = self.client
+        client = genai.Client(api_key=self.api_key)
         if formatting == "openai":
             messages = self._clean_messages_google(messages)
             config = types.GenerateContentConfig()
@@ -100,9 +122,9 @@ class GoogleLLM(BaseLLM):
         formatting="openai",
         **kwargs,
    ):
-        client = self.client
+        client = genai.Client(api_key=self.api_key)
         if formatting == "openai":
-            cleaned_messages = self._clean_messages_google(messages)
+            messages = self._clean_messages_google(messages)
             config = types.GenerateContentConfig()
             if messages[0].role == "system":
                 config.system_instruction = messages[0].parts[0].text
@@ -114,7 +136,7 @@ class GoogleLLM(BaseLLM):
 
         response = client.models.generate_content_stream(
             model=model,
-            contents=cleaned_messages,
+            contents=messages,
             config=config,
         )
         for chunk in response:
diff --git a/application/tools/llm_handler.py b/application/tools/llm_handler.py
index 2383d3f5..cc7494c0 100644
--- a/application/tools/llm_handler.py
+++ b/application/tools/llm_handler.py
@@ -61,14 +61,19 @@ class GoogleLLMHandler(LLMHandler):
                 tool_response, call_id = agent._execute_tool_action(
                     tools_dict, part.function_call
                 )
-
                 function_response_part = types.Part.from_function_response(
                     name=part.function_call.name,
                     response={"result": tool_response},
                 )
-                messages.append({"role": "model", "content": [part]})
+
                 messages.append(
-                    {"role": "tool", "content": [function_response_part]}
+                    {"role": "model", "content": [part.to_json_dict()]}
+                )
+                messages.append(
+                    {
+                        "role": "tool",
+                        "content": [function_response_part.to_json_dict()],
+                    }
                 )
 
             if (