fix: decorators + client error

Siddhant Rai
2025-01-20 19:44:14 +05:30
parent a741388447
commit d441d5763f
4 changed files with 65 additions and 17 deletions
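
In the GoogleLLM changes below, the genai.Client is no longer created once in __init__: the constructor now just stores api_key and user_api_key, and each generation call builds its own client. _clean_messages_google gains real conversion of list-form content into typed parts (text, function_call, function_response) instead of forwarding raw dictionaries, and the streaming method stops referencing cleaned_messages, a name that was only bound when formatting == "openai" yet was passed to generate_content_stream unconditionally.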

@@ -8,7 +8,8 @@ from application.llm.base import BaseLLM
 class GoogleLLM(BaseLLM):
     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.client = genai.Client(api_key=api_key)
+        self.api_key = api_key
+        self.user_api_key = user_api_key

     def _clean_messages_google(self, messages):
         cleaned_messages = []
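
With this change an instance no longer holds a live client; each request builds one from the stored key, so nothing has to fail at construction time. A minimal sketch of the pattern, assuming the google-genai SDK this file already uses; the class, model name, and prompt are placeholders, not part of this commit:

from google import genai


class LazyClientSketch:
    """Stores the key up front; creates the client only when a call is made."""

    def __init__(self, api_key=None):
        self.api_key = api_key  # no client construction here

    def generate(self, prompt):
        # Fresh client per call, mirroring the two generation methods below.
        client = genai.Client(api_key=self.api_key)
        return client.models.generate_content(
            model="gemini-2.0-flash",  # placeholder model name
            contents=prompt,
        ).text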
@@ -16,11 +17,32 @@ class GoogleLLM(BaseLLM):
             role = message.get("role")
             content = message.get("content")
+            parts = []

             if role and content is not None:
                 if isinstance(content, str):
                     parts = [types.Part.from_text(content)]
                 elif isinstance(content, list):
-                    parts = content
+                    for item in content:
+                        if "text" in item:
+                            parts.append(types.Part.from_text(item["text"]))
+                        elif "function_call" in item:
+                            parts.append(
+                                types.Part.from_function_call(
+                                    name=item["function_call"]["name"],
+                                    args=item["function_call"]["args"],
+                                )
+                            )
+                        elif "function_response" in item:
+                            parts.append(
+                                types.Part.from_function_response(
+                                    name=item["function_response"]["name"],
+                                    response=item["function_response"]["response"],
+                                )
+                            )
+                        else:
+                            raise ValueError(
+                                f"Unexpected content dictionary format: {item}"
+                            )
                 else:
                     raise ValueError(f"Unexpected content type: {type(content)}")
@@ -67,7 +89,7 @@ class GoogleLLM(BaseLLM):
         formatting="openai",
         **kwargs,
     ):
-        client = self.client
+        client = genai.Client(api_key=self.api_key)
         if formatting == "openai":
             messages = self._clean_messages_google(messages)
         config = types.GenerateContentConfig()
@@ -100,9 +122,9 @@
         formatting="openai",
         **kwargs,
     ):
-        client = self.client
+        client = genai.Client(api_key=self.api_key)
         if formatting == "openai":
-            cleaned_messages = self._clean_messages_google(messages)
+            messages = self._clean_messages_google(messages)
         config = types.GenerateContentConfig()
         if messages[0].role == "system":
             config.system_instruction = messages[0].parts[0].text
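
Both generation methods now share the same setup: a per-call client, cleaned messages, and a GenerateContentConfig whose system_instruction is lifted from a leading system message. In isolation, under the same imports this file uses, that setup looks roughly like the following (the key and instruction text are placeholders):

from google import genai
from google.genai import types

client = genai.Client(api_key="YOUR_KEY")  # placeholder key
config = types.GenerateContentConfig()
# What the code above assigns from messages[0] when its role is "system":
config.system_instruction = "You are a helpful assistant."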
@@ -114,7 +136,7 @@
         response = client.models.generate_content_stream(
             model=model,
-            contents=cleaned_messages,
+            contents=messages,
             config=config,
         )
         for chunk in response:
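
With this rename, contents always refers to a bound name: messages, cleaned or not. A hypothetical end-to-end use of the same streaming call, outside this class (placeholder key and model name):

from google import genai

client = genai.Client(api_key="YOUR_KEY")  # placeholder key
response = client.models.generate_content_stream(
    model="gemini-2.0-flash",  # placeholder model name
    contents="Summarize this commit in one line.",
)
for chunk in response:
    # Each chunk is a partial response; chunk.text carries the new tokens.
    print(chunk.text, end="")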