fix: decorators + client error

This commit is contained in:
Siddhant Rai
2025-01-20 19:44:14 +05:30
parent a741388447
commit d441d5763f
4 changed files with 65 additions and 17 deletions

View File

@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from application.cache import gen_cache, stream_cache
from application.usage import gen_token_usage, stream_token_usage
from application.cache import stream_cache, gen_cache
class BaseLLM(ABC):
@@ -18,18 +19,38 @@ class BaseLLM(ABC):
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
decorators = [gen_token_usage, gen_cache]
return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs)
return self._apply_decorator(
self._raw_gen,
decorators=decorators,
model=model,
messages=messages,
stream=stream,
tools=tools,
*args,
**kwargs
)
@abstractmethod
def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
pass
def gen_stream(self, model, messages, stream=True, *args, **kwargs):
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
decorators = [stream_cache, stream_token_usage]
return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
return self._apply_decorator(
self._raw_gen_stream,
decorators=decorators,
model=model,
messages=messages,
stream=stream,
tools=tools,
*args,
**kwargs
)
def supports_tools(self):
return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools'))
return hasattr(self, "_supports_tools") and callable(
getattr(self, "_supports_tools")
)
def _supports_tools(self):
raise NotImplementedError("Subclass must implement _supports_tools method")
raise NotImplementedError("Subclass must implement _supports_tools method")

View File

@@ -8,7 +8,8 @@ from application.llm.base import BaseLLM
class GoogleLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = genai.Client(api_key=api_key)
self.api_key = api_key
self.user_api_key = user_api_key
def _clean_messages_google(self, messages):
cleaned_messages = []
@@ -16,11 +17,32 @@ class GoogleLLM(BaseLLM):
role = message.get("role")
content = message.get("content")
parts = []
if role and content is not None:
if isinstance(content, str):
parts = [types.Part.from_text(content)]
elif isinstance(content, list):
parts = content
for item in content:
if "text" in item:
parts.append(types.Part.from_text(item["text"]))
elif "function_call" in item:
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],
args=item["function_call"]["args"],
)
)
elif "function_response" in item:
parts.append(
types.Part.from_function_response(
name=item["function_response"]["name"],
response=item["function_response"]["response"],
)
)
else:
raise ValueError(
f"Unexpected content dictionary format:{item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
@@ -67,7 +89,7 @@ class GoogleLLM(BaseLLM):
formatting="openai",
**kwargs,
):
client = self.client
client = genai.Client(api_key=self.api_key)
if formatting == "openai":
messages = self._clean_messages_google(messages)
config = types.GenerateContentConfig()
@@ -100,9 +122,9 @@ class GoogleLLM(BaseLLM):
formatting="openai",
**kwargs,
):
client = self.client
client = genai.Client(api_key=self.api_key)
if formatting == "openai":
cleaned_messages = self._clean_messages_google(messages)
messages = self._clean_messages_google(messages)
config = types.GenerateContentConfig()
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
@@ -114,7 +136,7 @@ class GoogleLLM(BaseLLM):
response = client.models.generate_content_stream(
model=model,
contents=cleaned_messages,
contents=messages,
config=config,
)
for chunk in response: