mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-23 04:41:47 +00:00
fix: token calc (#2285)
This commit is contained in:
@@ -23,6 +23,7 @@ class BaseAgent(ABC):
|
||||
llm_name: str,
|
||||
model_id: str,
|
||||
api_key: str,
|
||||
agent_id: Optional[str] = None,
|
||||
user_api_key: Optional[str] = None,
|
||||
prompt: str = "",
|
||||
chat_history: Optional[List[Dict]] = None,
|
||||
@@ -40,6 +41,7 @@ class BaseAgent(ABC):
|
||||
self.llm_name = llm_name
|
||||
self.model_id = model_id
|
||||
self.api_key = api_key
|
||||
self.agent_id = agent_id
|
||||
self.user_api_key = user_api_key
|
||||
self.prompt = prompt
|
||||
self.decoded_token = decoded_token or {}
|
||||
@@ -54,6 +56,7 @@ class BaseAgent(ABC):
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
self.retrieved_docs = retrieved_docs or []
|
||||
self.llm_handler = LLMHandlerCreator.create_handler(
|
||||
@@ -263,6 +266,11 @@ class BaseAgent(ABC):
|
||||
tool_config=tool_config,
|
||||
user_id=self.user,
|
||||
)
|
||||
resolved_arguments = (
|
||||
{"query_params": query_params, "headers": headers, "body": body}
|
||||
if tool_data["name"] == "api_tool"
|
||||
else parameters
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
logger.debug(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
@@ -292,11 +300,19 @@ class BaseAgent(ABC):
|
||||
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
|
||||
if artifact_id:
|
||||
tool_call_data["artifact_id"] = artifact_id
|
||||
result_full = str(result)
|
||||
tool_call_data["resolved_arguments"] = resolved_arguments
|
||||
tool_call_data["result_full"] = result_full
|
||||
tool_call_data["result"] = (
|
||||
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
|
||||
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
|
||||
)
|
||||
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
|
||||
stream_tool_call_data = {
|
||||
key: value
|
||||
for key, value in tool_call_data.items()
|
||||
if key not in {"result_full", "resolved_arguments"}
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
|
||||
return result, call_id
|
||||
@@ -304,7 +320,11 @@ class BaseAgent(ABC):
|
||||
def _get_truncated_tool_calls(self):
|
||||
return [
|
||||
{
|
||||
**tool_call,
|
||||
"tool_name": tool_call.get("tool_name"),
|
||||
"call_id": tool_call.get("call_id"),
|
||||
"action_name": tool_call.get("action_name"),
|
||||
"arguments": tool_call.get("arguments"),
|
||||
"artifact_id": tool_call.get("artifact_id"),
|
||||
"result": (
|
||||
f"{str(tool_call['result'])[:50]}..."
|
||||
if len(str(tool_call["result"])) > 50
|
||||
@@ -576,6 +596,9 @@ class BaseAgent(ABC):
|
||||
self._validate_context_size(messages)
|
||||
|
||||
gen_kwargs = {"model": self.model_id, "messages": messages}
|
||||
if self.attachments:
|
||||
# Usage accounting only; stripped before provider invocation.
|
||||
gen_kwargs["_usage_attachments"] = self.attachments
|
||||
|
||||
if (
|
||||
hasattr(self.llm, "_supports_tools")
|
||||
|
||||
Reference in New Issue
Block a user