Mirror of https://github.com/arc53/DocsGPT.git
fix: openai compatible with llama and gemini
@@ -72,9 +72,9 @@ class OpenAILLMHandler(LLMHandler):
         while True:
             tool_calls = {}
             for chunk in resp:
-                if isinstance(chunk, str):
+                if isinstance(chunk, str) and len(chunk) > 0:
                     return
-                else:
+                elif hasattr(chunk, "delta"):
                     chunk_delta = chunk.delta

                     if (
@@ -113,6 +113,8 @@ class OpenAILLMHandler(LLMHandler):
                 tool_response, call_id = agent._execute_tool_action(
                     tools_dict, call
                 )
+                if isinstance(call["function"]["arguments"], str):
+                    call["function"]["arguments"] = json.loads(call["function"]["arguments"])

                 function_call_dict = {
                     "function_call": {
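Aside (not part of the diff itself): the two lines added here normalize tool-call arguments, which arrive as a JSON string in OpenAI's wire format but can already be a decoded dict when a Gemini- or Llama-backed gateway fills the same structure. A minimal sketch of that normalization, using hypothetical call payloads:

import json

# Hypothetical tool calls as the two kinds of backend might deliver them.
call_from_openai = {"function": {"name": "search", "arguments": '{"q": "docs"}'}}
call_from_gemini = {"function": {"name": "search", "arguments": {"q": "docs"}}}

for call in (call_from_openai, call_from_gemini):
    # Same guard the hunk adds: decode only when arguments is still a string.
    if isinstance(call["function"]["arguments"], str):
        call["function"]["arguments"] = json.loads(call["function"]["arguments"])
    assert call["function"]["arguments"] == {"q": "docs"}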
@@ -156,6 +158,8 @@ class OpenAILLMHandler(LLMHandler):
                     and chunk.finish_reason == "stop"
                 ):
                     return
+                elif isinstance(chunk, str) and len(chunk) == 0:
+                    continue

         resp = agent.llm.gen_stream(
             model=agent.gpt_model, messages=messages, tools=agent.tools
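Taken together, the guards in the hunks above change how bare string chunks are treated: a non-empty string still ends the stream, objects carrying a delta attribute are processed as before, and an empty string (which some OpenAI-compatible Llama and Gemini gateways emit as a keep-alive) is now skipped instead of silently terminating the loop. A self-contained sketch of that dispatch; the stub chunk type and sample stream are hypothetical, not DocsGPT code:

from dataclasses import dataclass

@dataclass
class StubChunk:
    delta: str  # stands in for an OpenAI-style streaming delta object

def consume(resp):
    for chunk in resp:
        if isinstance(chunk, str) and len(chunk) > 0:
            return  # non-empty string: end of stream
        elif hasattr(chunk, "delta"):
            print("delta:", chunk.delta)
        elif isinstance(chunk, str) and len(chunk) == 0:
            continue  # empty keep-alive string: keep reading

# Before this commit the empty string would have matched the bare
# isinstance(chunk, str) branch and ended the stream after "Hel".
consume([StubChunk("Hel"), "", StubChunk("lo"), "done"])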
@@ -1,34 +1,132 @@
-from application.llm.base import BaseLLM
 import json
-import requests
+import sys

+from application.core.settings import settings
+from application.llm.base import BaseLLM
+

 class DocsGPTAPILLM(BaseLLM):

     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
+        from openai import OpenAI
+
         super().__init__(*args, **kwargs)
-        self.api_key = api_key
+        self.client = OpenAI(api_key="sk-docsgpt-public", base_url="https://oai.arc53.com")
         self.user_api_key = user_api_key
-        self.endpoint = "https://llm.arc53.com"
+        self.api_key = api_key

-    def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):
-        response = requests.post(
-            f"{self.endpoint}/answer", json={"messages": messages, "max_new_tokens": 30}
-        )
-        response_clean = response.json()["a"].replace("###", "")
-
-        return response_clean
-
-    def _raw_gen_stream(self, baseself, model, messages, stream=True, *args, **kwargs):
-        response = requests.post(
-            f"{self.endpoint}/stream",
-            json={"messages": messages, "max_new_tokens": 256},
-            stream=True,
-        )
-
-        for line in response.iter_lines():
-            if line:
-                data_str = line.decode("utf-8")
-                if data_str.startswith("data: "):
-                    data = json.loads(data_str[6:])
-                    yield data["a"]
+    def _clean_messages_openai(self, messages):
+        cleaned_messages = []
+        for message in messages:
+            role = message.get("role")
+            content = message.get("content")
+
+            if role == "model":
+                role = "assistant"
+
+            if role and content is not None:
+                if isinstance(content, str):
+                    cleaned_messages.append({"role": role, "content": content})
+                elif isinstance(content, list):
+                    for item in content:
+                        if "text" in item:
+                            cleaned_messages.append(
+                                {"role": role, "content": item["text"]}
+                            )
+                        elif "function_call" in item:
+                            tool_call = {
+                                "id": item["function_call"]["call_id"],
+                                "type": "function",
+                                "function": {
+                                    "name": item["function_call"]["name"],
+                                    "arguments": json.dumps(
+                                        item["function_call"]["args"]
+                                    ),
+                                },
+                            }
+                            cleaned_messages.append(
+                                {
+                                    "role": "assistant",
+                                    "content": None,
+                                    "tool_calls": [tool_call],
+                                }
+                            )
+                        elif "function_response" in item:
+                            cleaned_messages.append(
+                                {
+                                    "role": "tool",
+                                    "tool_call_id": item["function_response"][
+                                        "call_id"
+                                    ],
+                                    "content": json.dumps(
+                                        item["function_response"]["response"]["result"]
+                                    ),
+                                }
+                            )
+                        else:
+                            raise ValueError(
+                                f"Unexpected content dictionary format: {item}"
+                            )
+                else:
+                    raise ValueError(f"Unexpected content type: {type(content)}")
+
+        return cleaned_messages
+
+    def _raw_gen(
+        self,
+        baseself,
+        model,
+        messages,
+        stream=False,
+        tools=None,
+        engine=settings.AZURE_DEPLOYMENT_NAME,
+        **kwargs,
+    ):
+        messages = self._clean_messages_openai(messages)
+        if tools:
+            response = self.client.chat.completions.create(
+                model="docsgpt",
+                messages=messages,
+                stream=stream,
+                tools=tools,
+                **kwargs,
+            )
+            return response.choices[0]
+        else:
+            response = self.client.chat.completions.create(
+                model="docsgpt", messages=messages, stream=stream, **kwargs
+            )
+            return response.choices[0].message.content
+
+    def _raw_gen_stream(
+        self,
+        baseself,
+        model,
+        messages,
+        stream=True,
+        tools=None,
+        engine=settings.AZURE_DEPLOYMENT_NAME,
+        **kwargs,
+    ):
+        messages = self._clean_messages_openai(messages)
+        if tools:
+            response = self.client.chat.completions.create(
+                model="docsgpt",
+                messages=messages,
+                stream=stream,
+                tools=tools,
+                **kwargs,
+            )
+        else:
+            response = self.client.chat.completions.create(
+                model="docsgpt", messages=messages, stream=stream, **kwargs
+            )
+
+        for line in response:
+            if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
+                yield line.choices[0].delta.content
+            elif len(line.choices) > 0:
+                yield line.choices[0]
+
+    def _supports_tools(self):
+        return True
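To see what the new _clean_messages_openai does, here is a hedged before/after sketch: a hypothetical Gemini-style history (role "model", content as a list of parts with function_call / function_response entries) and the OpenAI-style messages the method above would produce for it. The sample values are illustrative only, not from the repository:

import json

gemini_style_history = [
    {"role": "user", "content": "find the docs"},
    {
        "role": "model",
        "content": [
            {"function_call": {"call_id": "call_1", "name": "search", "args": {"q": "docs"}}}
        ],
    },
    {
        "role": "user",
        "content": [
            {"function_response": {"call_id": "call_1", "response": {"result": ["hit #1"]}}}
        ],
    },
]

# Expected output, following the branches of _clean_messages_openai:
openai_style_messages = [
    {"role": "user", "content": "find the docs"},
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "search", "arguments": json.dumps({"q": "docs"})},
            }
        ],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": json.dumps(["hit #1"])},
]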
@@ -125,9 +125,9 @@ class OpenAILLM(BaseLLM):
         )

         for line in response:
-            if line.choices[0].delta.content is not None:
+            if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
                 yield line.choices[0].delta.content
-            else:
+            elif len(line.choices) > 0:
                 yield line.choices[0]

     def _supports_tools(self):
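A note on the repeated len(line.choices) > 0 checks in this and the previous file: OpenAI-compatible servers can emit stream chunks whose choices list is empty (OpenAI itself does this for the final usage chunk when stream_options={"include_usage": True} is requested), so indexing line.choices[0] unguarded raises IndexError. A small sketch with hypothetical stub types standing in for the SDK's streaming objects:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class StubDelta:
    content: Optional[str] = None

@dataclass
class StubChoice:
    delta: StubDelta = field(default_factory=StubDelta)

@dataclass
class StubStreamChunk:
    choices: list = field(default_factory=list)

stream = [
    StubStreamChunk(choices=[StubChoice(StubDelta("Hi"))]),   # text delta: yielded
    StubStreamChunk(choices=[]),                              # e.g. usage-only chunk: safely ignored
    StubStreamChunk(choices=[StubChoice(StubDelta(None))]),   # tool-call-style delta: passed through
]

for line in stream:
    if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
        print("text:", line.choices[0].delta.content)
    elif len(line.choices) > 0:
        print("non-text chunk:", line.choices[0])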