mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-30 17:13:15 +00:00
fix: decorators + client error
This commit is contained in:
@@ -8,7 +8,8 @@ from application.llm.base import BaseLLM
|
||||
class GoogleLLM(BaseLLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.client = genai.Client(api_key=api_key)
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
def _clean_messages_google(self, messages):
|
||||
cleaned_messages = []
|
||||
@@ -16,11 +17,32 @@ class GoogleLLM(BaseLLM):
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
|
||||
parts = []
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
parts = [types.Part.from_text(content)]
|
||||
elif isinstance(content, list):
|
||||
parts = content
|
||||
for item in content:
|
||||
if "text" in item:
|
||||
parts.append(types.Part.from_text(item["text"]))
|
||||
elif "function_call" in item:
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=item["function_call"]["name"],
|
||||
args=item["function_call"]["args"],
|
||||
)
|
||||
)
|
||||
elif "function_response" in item:
|
||||
parts.append(
|
||||
types.Part.from_function_response(
|
||||
name=item["function_response"]["name"],
|
||||
response=item["function_response"]["response"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content dictionary format:{item}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
|
||||
@@ -67,7 +89,7 @@ class GoogleLLM(BaseLLM):
|
||||
formatting="openai",
|
||||
**kwargs,
|
||||
):
|
||||
client = self.client
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
if formatting == "openai":
|
||||
messages = self._clean_messages_google(messages)
|
||||
config = types.GenerateContentConfig()
|
||||
@@ -100,9 +122,9 @@ class GoogleLLM(BaseLLM):
|
||||
formatting="openai",
|
||||
**kwargs,
|
||||
):
|
||||
client = self.client
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
if formatting == "openai":
|
||||
cleaned_messages = self._clean_messages_google(messages)
|
||||
messages = self._clean_messages_google(messages)
|
||||
config = types.GenerateContentConfig()
|
||||
if messages[0].role == "system":
|
||||
config.system_instruction = messages[0].parts[0].text
|
||||
@@ -114,7 +136,7 @@ class GoogleLLM(BaseLLM):
|
||||
|
||||
response = client.models.generate_content_stream(
|
||||
model=model,
|
||||
contents=cleaned_messages,
|
||||
contents=messages,
|
||||
config=config,
|
||||
)
|
||||
for chunk in response:
|
||||
|
||||
Reference in New Issue
Block a user