fix: token calc (#2285)

This commit is contained in:
Alex
2026-02-20 17:37:47 +00:00
committed by GitHub
parent 444abb8283
commit 1a2104f474
15 changed files with 543 additions and 30 deletions

View File

@@ -23,6 +23,7 @@ class BaseAgent(ABC):
llm_name: str,
model_id: str,
api_key: str,
agent_id: Optional[str] = None,
user_api_key: Optional[str] = None,
prompt: str = "",
chat_history: Optional[List[Dict]] = None,
@@ -40,6 +41,7 @@ class BaseAgent(ABC):
self.llm_name = llm_name
self.model_id = model_id
self.api_key = api_key
self.agent_id = agent_id
self.user_api_key = user_api_key
self.prompt = prompt
self.decoded_token = decoded_token or {}
@@ -54,6 +56,7 @@ class BaseAgent(ABC):
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
self.retrieved_docs = retrieved_docs or []
self.llm_handler = LLMHandlerCreator.create_handler(
@@ -263,6 +266,11 @@ class BaseAgent(ABC):
tool_config=tool_config,
user_id=self.user,
)
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
if tool_data["name"] == "api_tool"
else parameters
)
if tool_data["name"] == "api_tool":
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
@@ -292,11 +300,19 @@ class BaseAgent(ABC):
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
if artifact_id:
tool_call_data["artifact_id"] = artifact_id
result_full = str(result)
tool_call_data["resolved_arguments"] = resolved_arguments
tool_call_data["result_full"] = result_full
tool_call_data["result"] = (
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
)
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
stream_tool_call_data = {
key: value
for key, value in tool_call_data.items()
if key not in {"result_full", "resolved_arguments"}
}
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
@@ -304,7 +320,11 @@ class BaseAgent(ABC):
def _get_truncated_tool_calls(self):
return [
{
**tool_call,
"tool_name": tool_call.get("tool_name"),
"call_id": tool_call.get("call_id"),
"action_name": tool_call.get("action_name"),
"arguments": tool_call.get("arguments"),
"artifact_id": tool_call.get("artifact_id"),
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
@@ -576,6 +596,9 @@ class BaseAgent(ABC):
self._validate_context_size(messages)
gen_kwargs = {"model": self.model_id, "messages": messages}
if self.attachments:
# Usage accounting only; stripped before provider invocation.
gen_kwargs["_usage_attachments"] = self.attachments
if (
hasattr(self.llm, "_supports_tools")