mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
fix: decorators + client error
This commit is contained in:
@@ -68,7 +68,7 @@ def gen_cache(func):
|
||||
|
||||
|
||||
def stream_cache(func):
|
||||
def wrapper(self, model, messages, stream, *args, **kwargs):
|
||||
def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
|
||||
cache_key = gen_cache_key(messages)
|
||||
logger.info(f"Stream cache key: {cache_key}")
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from application.cache import gen_cache, stream_cache
|
||||
from application.usage import gen_token_usage, stream_token_usage
|
||||
from application.cache import stream_cache, gen_cache
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
@@ -18,18 +19,38 @@ class BaseLLM(ABC):
|
||||
|
||||
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
|
||||
decorators = [gen_token_usage, gen_cache]
|
||||
return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs)
|
||||
return self._apply_decorator(
|
||||
self._raw_gen,
|
||||
decorators=decorators,
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def gen_stream(self, model, messages, stream=True, *args, **kwargs):
|
||||
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
|
||||
decorators = [stream_cache, stream_token_usage]
|
||||
return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
|
||||
|
||||
return self._apply_decorator(
|
||||
self._raw_gen_stream,
|
||||
decorators=decorators,
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def supports_tools(self):
|
||||
return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools'))
|
||||
return hasattr(self, "_supports_tools") and callable(
|
||||
getattr(self, "_supports_tools")
|
||||
)
|
||||
|
||||
def _supports_tools(self):
|
||||
raise NotImplementedError("Subclass must implement _supports_tools method")
|
||||
raise NotImplementedError("Subclass must implement _supports_tools method")
|
||||
|
||||
@@ -8,7 +8,8 @@ from application.llm.base import BaseLLM
|
||||
class GoogleLLM(BaseLLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.client = genai.Client(api_key=api_key)
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
def _clean_messages_google(self, messages):
|
||||
cleaned_messages = []
|
||||
@@ -16,11 +17,32 @@ class GoogleLLM(BaseLLM):
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
|
||||
parts = []
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
parts = [types.Part.from_text(content)]
|
||||
elif isinstance(content, list):
|
||||
parts = content
|
||||
for item in content:
|
||||
if "text" in item:
|
||||
parts.append(types.Part.from_text(item["text"]))
|
||||
elif "function_call" in item:
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=item["function_call"]["name"],
|
||||
args=item["function_call"]["args"],
|
||||
)
|
||||
)
|
||||
elif "function_response" in item:
|
||||
parts.append(
|
||||
types.Part.from_function_response(
|
||||
name=item["function_response"]["name"],
|
||||
response=item["function_response"]["response"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content dictionary format:{item}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
|
||||
@@ -67,7 +89,7 @@ class GoogleLLM(BaseLLM):
|
||||
formatting="openai",
|
||||
**kwargs,
|
||||
):
|
||||
client = self.client
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
if formatting == "openai":
|
||||
messages = self._clean_messages_google(messages)
|
||||
config = types.GenerateContentConfig()
|
||||
@@ -100,9 +122,9 @@ class GoogleLLM(BaseLLM):
|
||||
formatting="openai",
|
||||
**kwargs,
|
||||
):
|
||||
client = self.client
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
if formatting == "openai":
|
||||
cleaned_messages = self._clean_messages_google(messages)
|
||||
messages = self._clean_messages_google(messages)
|
||||
config = types.GenerateContentConfig()
|
||||
if messages[0].role == "system":
|
||||
config.system_instruction = messages[0].parts[0].text
|
||||
@@ -114,7 +136,7 @@ class GoogleLLM(BaseLLM):
|
||||
|
||||
response = client.models.generate_content_stream(
|
||||
model=model,
|
||||
contents=cleaned_messages,
|
||||
contents=messages,
|
||||
config=config,
|
||||
)
|
||||
for chunk in response:
|
||||
|
||||
@@ -61,14 +61,19 @@ class GoogleLLMHandler(LLMHandler):
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, part.function_call
|
||||
)
|
||||
|
||||
function_response_part = types.Part.from_function_response(
|
||||
name=part.function_call.name,
|
||||
response={"result": tool_response},
|
||||
)
|
||||
messages.append({"role": "model", "content": [part]})
|
||||
|
||||
messages.append(
|
||||
{"role": "tool", "content": [function_response_part]}
|
||||
{"role": "model", "content": [part.to_json_dict()]}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [function_response_part.to_json_dict()],
|
||||
}
|
||||
)
|
||||
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user