fix: token calc (#2285)

Alex
2026-02-20 17:37:47 +00:00
committed by GitHub
parent 444abb8283
commit 1a2104f474
15 changed files with 543 additions and 30 deletions

View File

@@ -23,6 +23,7 @@ class BaseAgent(ABC):
llm_name: str,
model_id: str,
api_key: str,
agent_id: Optional[str] = None,
user_api_key: Optional[str] = None,
prompt: str = "",
chat_history: Optional[List[Dict]] = None,
@@ -40,6 +41,7 @@ class BaseAgent(ABC):
self.llm_name = llm_name
self.model_id = model_id
self.api_key = api_key
self.agent_id = agent_id
self.user_api_key = user_api_key
self.prompt = prompt
self.decoded_token = decoded_token or {}
@@ -54,6 +56,7 @@ class BaseAgent(ABC):
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
self.retrieved_docs = retrieved_docs or []
self.llm_handler = LLMHandlerCreator.create_handler(
@@ -263,6 +266,11 @@ class BaseAgent(ABC):
tool_config=tool_config,
user_id=self.user,
)
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
if tool_data["name"] == "api_tool"
else parameters
)
if tool_data["name"] == "api_tool":
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
@@ -292,11 +300,19 @@ class BaseAgent(ABC):
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
if artifact_id:
tool_call_data["artifact_id"] = artifact_id
+ result_full = str(result)
+ tool_call_data["resolved_arguments"] = resolved_arguments
+ tool_call_data["result_full"] = result_full
tool_call_data["result"] = (
- f"{str(result)[:50]}..." if len(str(result)) > 50 else result
+ f"{result_full[:50]}..." if len(result_full) > 50 else result_full
)
- yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
+ stream_tool_call_data = {
+ key: value
+ for key, value in tool_call_data.items()
+ if key not in {"result_full", "resolved_arguments"}
+ }
+ yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
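
Review note: the persisted tool_call_data now carries the untruncated result_full and the resolved_arguments for logging, while the streamed SSE event strips both so payloads stay small; only the 50-char result preview goes over the wire. A minimal sketch of the split (field values hypothetical):

# Sketch: what is persisted vs. what is streamed.
tool_call_data = {
    "tool_name": "api_tool",
    "resolved_arguments": {"query_params": {}, "headers": {}, "body": {}},  # log only
    "result_full": "x" * 5000,  # log only
    "result": ("x" * 50) + "...",  # 50-char preview
}
streamed = {
    key: value
    for key, value in tool_call_data.items()
    if key not in {"result_full", "resolved_arguments"}
}
assert "result_full" not in streamed and "result" in streamed
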
@@ -304,7 +320,11 @@ class BaseAgent(ABC):
def _get_truncated_tool_calls(self):
return [
{
- **tool_call,
+ "tool_name": tool_call.get("tool_name"),
+ "call_id": tool_call.get("call_id"),
+ "action_name": tool_call.get("action_name"),
+ "arguments": tool_call.get("arguments"),
+ "artifact_id": tool_call.get("artifact_id"),
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
@@ -576,6 +596,9 @@ class BaseAgent(ABC):
self._validate_context_size(messages)
gen_kwargs = {"model": self.model_id, "messages": messages}
if self.attachments:
# Usage accounting only; stripped before provider invocation.
gen_kwargs["_usage_attachments"] = self.attachments
if (
hasattr(self.llm, "_supports_tools")
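
Review note: _usage_attachments is a usage-accounting sentinel, not a real provider kwarg; the gen_token_usage/stream_token_usage wrappers (see application/usage.py below) pop it before the underlying completion method runs. A minimal sketch of that contract, with a hypothetical decorator name:

def strip_usage_sentinel(func):
    # Count the attachment payload for billing, then drop it so the
    # provider never receives an unknown kwarg.
    def wrapper(self, model, messages, stream, tools, **kwargs):
        usage_attachments = kwargs.pop("_usage_attachments", None)
        _ = usage_attachments  # counted for usage; never forwarded
        return func(self, model, messages, stream, tools, **kwargs)
    return wrapper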

View File

@@ -42,6 +42,7 @@ class AnswerResource(Resource, BaseAnswerResource):
),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"agent_id": fields.String(required=False, description="Agent ID"),
"active_docs": fields.String(
required=False, description="Active documents"
),
@@ -100,6 +101,9 @@ class AnswerResource(Resource, BaseAnswerResource):
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)

View File

@@ -46,6 +46,27 @@ class BaseAnswerResource:
return missing_fields
return None
@staticmethod
def _prepare_tool_calls_for_logging(
tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000
) -> List[Dict[str, Any]]:
if not tool_calls:
return []
prepared = []
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
prepared.append({"result": str(tool_call)[:max_chars]})
continue
item = dict(tool_call)
for key in ("result", "result_full"):
value = item.get(key)
if isinstance(value, str) and len(value) > max_chars:
item[key] = value[:max_chars]
prepared.append(item)
return prepared
def check_usage(self, agent_config: Dict) -> Optional[Response]:
"""Check if there is a usage limit and if it is exceeded
@@ -246,6 +267,7 @@ class BaseAnswerResource:
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
if should_save_conversation:
@@ -292,14 +314,20 @@ class BaseAnswerResource:
data = json.dumps(id_data)
yield f"data: {data}\n\n"
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
getattr(agent, "tool_calls", tool_calls) or tool_calls
)
log_data = {
"action": "stream_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"agent_id": agent_id,
"question": question,
"response": response_full,
"sources": source_log_docs,
"tool_calls": tool_calls_for_logging,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
@@ -330,6 +358,7 @@ class BaseAnswerResource:
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
agent_id=agent_id,
)
self.conversation_service.save_conversation(
conversation_id,
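
Review note: _prepare_tool_calls_for_logging copies each tool call and caps the result and result_full fields at max_chars (10,000 by default) so log documents stay bounded; non-dict entries are stringified and truncated. A quick usage sketch (values hypothetical):

calls = [
    {"tool_name": "api_tool", "result_full": "y" * 20000},
    "bare string result",
]
prepared = BaseAnswerResource._prepare_tool_calls_for_logging(calls, max_chars=100)
assert len(prepared[0]["result_full"]) == 100  # truncated on a shallow copy
assert prepared[1] == {"result": "bare string result"}  # non-dict wrapped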

View File

@@ -42,6 +42,7 @@ class StreamResource(Resource, BaseAnswerResource):
),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"agent_id": fields.String(required=False, description="Agent ID"),
"active_docs": fields.String(
required=False, description="Active documents"
),
@@ -107,7 +108,7 @@ class StreamResource(Resource, BaseAnswerResource):
index=data.get("index"),
should_save_conversation=data.get("save_conversation", True),
attachment_ids=data.get("attachments", []),
- agent_id=data.get("agent_id"),
+ agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,

View File

@@ -134,6 +134,7 @@ class CompressionOrchestrator:
user_api_key=None,
decoded_token=decoded_token,
model_id=compression_model,
agent_id=conversation.get("agent_id"),
)
# Create compression service with DB update capability

View File

@@ -90,6 +90,7 @@ class StreamProcessor:
self.retriever_config = {}
self.is_shared_usage = False
self.shared_token = None
self.agent_id = self.data.get("agent_id")
self.model_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
@@ -355,10 +356,13 @@ class StreamProcessor:
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
agent_id, self.initial_user_id
)
self.agent_id = str(agent_id) if agent_id else None
api_key = self.data.get("api_key")
if api_key:
data_key = self._get_data_from_api_key(api_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
@@ -387,6 +391,8 @@ class StreamProcessor:
self.retriever_config["chunks"] = 2
elif self.agent_key:
data_key = self._get_data_from_api_key(self.agent_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
@@ -459,6 +465,7 @@ class StreamProcessor:
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
model_id=self.model_id,
user_api_key=self.agent_config["user_api_key"],
agent_id=self.agent_id,
decoded_token=self.decoded_token,
)
@@ -754,6 +761,7 @@ class StreamProcessor:
"llm_name": provider or settings.LLM_PROVIDER,
"model_id": self.model_id,
"api_key": system_api_key,
"agent_id": self.agent_id,
"user_api_key": self.agent_config["user_api_key"],
"prompt": rendered_prompt,
"chat_history": self.history,

View File

@@ -13,10 +13,12 @@ class BaseLLM(ABC):
def __init__(
self,
decoded_token=None,
agent_id=None,
model_id=None,
base_url=None,
):
self.decoded_token = decoded_token
self.agent_id = str(agent_id) if agent_id else None
self.model_id = model_id
self.base_url = base_url
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
@@ -33,9 +35,10 @@ class BaseLLM(ABC):
self._fallback_llm = LLMCreator.create_llm(
settings.FALLBACK_LLM_PROVIDER,
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
- user_api_key=None,
+ user_api_key=getattr(self, "user_api_key", None),
decoded_token=self.decoded_token,
model_id=settings.FALLBACK_LLM_NAME,
agent_id=self.agent_id,
)
logger.info(
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"

View File

@@ -13,7 +13,7 @@ class GoogleLLM(BaseLLM):
def __init__(
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
):
- super().__init__(*args, **kwargs)
+ super().__init__(decoded_token=decoded_token, *args, **kwargs)
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
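
Review note: this one-keyword fix is load-bearing. decoded_token was a named parameter of GoogleLLM.__init__, so it never travelled through *args/**kwargs to BaseLLM.__init__, and self.decoded_token stayed None, dropping per-user usage attribution for Google models. A minimal reproduction of the pattern:

class Base:
    def __init__(self, decoded_token=None, **kwargs):
        self.decoded_token = decoded_token

class Broken(Base):
    def __init__(self, api_key=None, decoded_token=None, **kwargs):
        super().__init__(**kwargs)  # decoded_token is silently dropped

class Fixed(Base):
    def __init__(self, api_key=None, decoded_token=None, **kwargs):
        super().__init__(decoded_token=decoded_token, **kwargs)

assert Broken(decoded_token={"sub": "u"}).decoded_token is None
assert Fixed(decoded_token={"sub": "u"}).decoded_token == {"sub": "u"}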

View File

@@ -567,6 +567,7 @@ class LLMHandler(ABC):
getattr(agent, "user_api_key", None),
getattr(agent, "decoded_token", None),
model_id=compression_model,
agent_id=getattr(agent, "agent_id", None),
)
# Create service without DB persistence capability

View File

@@ -31,7 +31,15 @@ class LLMCreator:
@classmethod
def create_llm(
- cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
+ cls,
+ type,
+ api_key,
+ user_api_key,
+ decoded_token,
+ model_id=None,
+ agent_id=None,
+ *args,
+ **kwargs,
):
from application.core.model_utils import get_base_url_for_model
@@ -49,6 +57,7 @@ class LLMCreator:
user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
base_url=base_url,
*args,
**kwargs,

View File

@@ -18,6 +18,7 @@ class ClassicRAG(BaseRetriever):
doc_token_limit=50000,
model_id="docsgpt-local",
user_api_key=None,
agent_id=None,
llm_name=settings.LLM_PROVIDER,
api_key=settings.API_KEY,
decoded_token=None,
@@ -43,6 +44,7 @@ class ClassicRAG(BaseRetriever):
self.model_id = model_id
self.doc_token_limit = doc_token_limit
self.user_api_key = user_api_key
self.agent_id = agent_id
self.llm_name = llm_name
self.api_key = api_key
self.llm = LLMCreator.create_llm(
@@ -50,6 +52,7 @@ class ClassicRAG(BaseRetriever):
api_key=self.api_key,
user_api_key=self.user_api_key,
decoded_token=decoded_token,
agent_id=self.agent_id,
)
if "active_docs" in source and source["active_docs"] is not None:

View File

@@ -1,22 +1,104 @@
import sys
import logging
from datetime import datetime
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.utils import num_tokens_from_object_or_list, num_tokens_from_string
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
usage_collection = db["token_usage"]
- def update_token_usage(decoded_token, user_api_key, token_usage):
def _serialize_for_token_count(value):
"""Normalize payloads into token-countable primitives."""
if isinstance(value, str):
# Avoid counting large binary payloads in data URLs as text tokens.
if value.startswith("data:") and ";base64," in value:
return ""
return value
if value is None:
return ""
if isinstance(value, list):
return [_serialize_for_token_count(item) for item in value]
if isinstance(value, dict):
serialized = {}
for key, raw in value.items():
key_lower = str(key).lower()
# Skip raw binary-like fields; keep textual tool-call fields.
if key_lower in {"data", "base64", "image_data"} and isinstance(raw, str):
continue
if key_lower == "url" and isinstance(raw, str) and ";base64," in raw:
continue
serialized[key] = _serialize_for_token_count(raw)
return serialized
if hasattr(value, "model_dump") and callable(getattr(value, "model_dump")):
return _serialize_for_token_count(value.model_dump())
if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")):
return _serialize_for_token_count(value.to_dict())
if hasattr(value, "__dict__"):
return _serialize_for_token_count(vars(value))
return str(value)
def _count_tokens(value):
serialized = _serialize_for_token_count(value)
if isinstance(serialized, str):
return num_tokens_from_string(serialized)
return num_tokens_from_object_or_list(serialized)
def _count_prompt_tokens(messages, tools=None, usage_attachments=None, **kwargs):
prompt_tokens = 0
for message in messages or []:
if not isinstance(message, dict):
prompt_tokens += _count_tokens(message)
continue
prompt_tokens += _count_tokens(message.get("content"))
# Include tool-related message fields for providers that use OpenAI-native format.
prompt_tokens += _count_tokens(message.get("tool_calls"))
prompt_tokens += _count_tokens(message.get("tool_call_id"))
prompt_tokens += _count_tokens(message.get("function_call"))
prompt_tokens += _count_tokens(message.get("function_response"))
# Count tool schema payload passed to the model.
prompt_tokens += _count_tokens(tools)
# Count structured-output/schema payloads when provided.
prompt_tokens += _count_tokens(kwargs.get("response_format"))
prompt_tokens += _count_tokens(kwargs.get("response_schema"))
# Optional usage-only attachment context (not forwarded to provider).
prompt_tokens += _count_tokens(usage_attachments)
return prompt_tokens
def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
if "pytest" in sys.modules:
return
- if decoded_token:
- user_id = decoded_token["sub"]
- else:
- user_id = None
+ user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
normalized_agent_id = str(agent_id) if agent_id else None
if not user_id and not user_api_key and not normalized_agent_id:
logger.warning(
"Skipping token usage insert: missing user_id, api_key, and agent_id"
)
return
usage_data = {
"user_id": user_id,
"api_key": user_api_key,
@@ -24,24 +106,31 @@ def update_token_usage(decoded_token, user_api_key, token_usage):
"generated_tokens": token_usage["generated_tokens"],
"timestamp": datetime.now(),
}
if normalized_agent_id:
usage_data["agent_id"] = normalized_agent_id
usage_collection.insert_one(usage_data)
def gen_token_usage(func):
def wrapper(self, model, messages, stream, tools, **kwargs):
- for message in messages:
- if message["content"]:
- self.token_usage["prompt_tokens"] += num_tokens_from_string(
- message["content"]
- )
+ usage_attachments = kwargs.pop("_usage_attachments", None)
+ call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
+ call_usage["prompt_tokens"] += _count_prompt_tokens(
+ messages,
+ tools=tools,
+ usage_attachments=usage_attachments,
+ **kwargs,
+ )
result = func(self, model, messages, stream, tools, **kwargs)
- if isinstance(result, str):
- self.token_usage["generated_tokens"] += num_tokens_from_string(result)
- else:
- self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
- result
- )
- update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
+ call_usage["generated_tokens"] += _count_tokens(result)
+ self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
+ self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
+ update_token_usage(
+ self.decoded_token,
+ self.user_api_key,
+ call_usage,
+ getattr(self, "agent_id", None),
+ )
return result
return wrapper
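
Review note: this is the heart of the token-calc fix. The old wrapper added into the instance-wide self.token_usage and then passed that cumulative dict to update_token_usage on every call, so a reused LLM instance re-billed earlier calls on each insert; the new code totals a fresh call_usage per call and records only the delta (self.token_usage still tracks the running total). Worked arithmetic, with hypothetical counts:

# Two calls on one instance: 100 then 50 prompt tokens.
calls = [100, 50]
inserted_old, running = [], 0
for tokens in calls:
    running += tokens
    inserted_old.append(running)  # old: cumulative snapshot per insert
inserted_new = list(calls)        # new: per-call delta per insert
assert sum(inserted_old) == 250   # over-billed
assert sum(inserted_new) == 150   # actual usage
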
@@ -49,17 +138,28 @@ def gen_token_usage(func):
def stream_token_usage(func):
def wrapper(self, model, messages, stream, tools, **kwargs):
- for message in messages:
- self.token_usage["prompt_tokens"] += num_tokens_from_string(
- message["content"]
- )
+ usage_attachments = kwargs.pop("_usage_attachments", None)
+ call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
+ call_usage["prompt_tokens"] += _count_prompt_tokens(
+ messages,
+ tools=tools,
+ usage_attachments=usage_attachments,
+ **kwargs,
+ )
batch = []
result = func(self, model, messages, stream, tools, **kwargs)
for r in result:
batch.append(r)
yield r
for line in batch:
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
call_usage["generated_tokens"] += _count_tokens(line)
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
update_token_usage(
self.decoded_token,
self.user_api_key,
call_usage,
getattr(self, "agent_id", None),
)
return wrapper
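
Review note on _serialize_for_token_count above: base64 data URLs and binary-ish dict fields (data, base64, image_data, and base64-bearing url values) are dropped before counting, so an inline image is not billed as if its encoding were prose. A small sketch of the effect (shapes hypothetical):

content = [
    {"type": "text", "text": "What is in this image?"},
    {"type": "image_url", "url": "data:image/png;base64," + "A" * 100000},
]
serialized = _serialize_for_token_count(content)
assert serialized[0]["text"] == "What is in this image?"  # text kept
assert "url" not in serialized[1]  # base64 url dropped, not counted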

View File

@@ -322,6 +322,7 @@ def run_agent_logic(agent_config, input_data):
chunks = int(agent_config.get("chunks", 2))
prompt_id = agent_config.get("prompt_id", "default")
user_api_key = agent_config["key"]
agent_id = str(agent_config.get("_id")) if agent_config.get("_id") else None
agent_type = agent_config.get("agent_type", "classic")
decoded_token = {"sub": agent_config.get("user")}
json_schema = agent_config.get("json_schema")
@@ -352,6 +353,7 @@ def run_agent_logic(agent_config, input_data):
doc_token_limit=doc_token_limit,
model_id=model_id,
user_api_key=user_api_key,
agent_id=agent_id,
decoded_token=decoded_token,
)
@@ -370,6 +372,7 @@ def run_agent_logic(agent_config, input_data):
llm_name=provider or settings.LLM_PROVIDER,
model_id=model_id,
api_key=system_api_key,
agent_id=agent_id,
user_api_key=user_api_key,
prompt=prompt,
chat_history=[],
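
Review note: the worker derives agent_id from the agent config's Mongo _id, stringifying it when present. A minimal sketch of that coercion (assuming pymongo's bson.ObjectId):

from bson.objectid import ObjectId

agent_config = {"_id": ObjectId("65f0c0ffee65f0c0ffee65f0")}
agent_id = str(agent_config.get("_id")) if agent_config.get("_id") else None
assert agent_id == "65f0c0ffee65f0c0ffee65f0"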

View File

@@ -199,6 +199,7 @@ class TestStreamProcessorAgentConfiguration:
try:
processor._configure_agent()
assert processor.agent_config is not None
assert processor.agent_id == str(agent_id)
except Exception as e:
assert "Invalid API Key" in str(e)
@@ -211,6 +212,7 @@ class TestStreamProcessorAgentConfiguration:
processor._configure_agent()
assert isinstance(processor.agent_config, dict)
assert processor.agent_id is None
@pytest.mark.unit

tests/test_usage.py (new file, 326 lines)
View File

@@ -0,0 +1,326 @@
import sys
import pytest
from application.usage import (
_count_tokens,
gen_token_usage,
stream_token_usage,
update_token_usage,
)
@pytest.mark.unit
def test_count_tokens_includes_tool_call_payloads():
payload = [
{
"function_call": {
"name": "search_docs",
"args": {"query": "pricing limits"},
"call_id": "call_1",
}
},
{
"function_response": {
"name": "search_docs",
"response": {"result": "Found 3 docs"},
"call_id": "call_1",
}
},
]
assert _count_tokens(payload) > 0
@pytest.mark.unit
def test_gen_token_usage_counts_structured_tool_content(monkeypatch):
captured = {}
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
captured["decoded_token"] = decoded_token
captured["user_api_key"] = user_api_key
captured["token_usage"] = token_usage.copy()
captured["agent_id"] = agent_id
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
class DummyLLM:
decoded_token = {"sub": "user_123"}
user_api_key = "api_key_123"
agent_id = "agent_123"
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
@gen_token_usage
def wrapped(self, model, messages, stream, tools, **kwargs):
_ = (model, messages, stream, tools, kwargs)
return {
"tool_calls": [
{"name": "read_webpage", "arguments": {"url": "https://example.com"}}
]
}
messages = [
{
"role": "assistant",
"content": [
{
"function_call": {
"name": "search_docs",
"args": {"query": "pricing"},
"call_id": "1",
}
}
],
},
{
"role": "tool",
"content": [
{
"function_response": {
"name": "search_docs",
"response": {"result": "Found docs"},
"call_id": "1",
}
}
],
},
]
llm = DummyLLM()
wrapped(llm, "gpt-4o", messages, False, None)
assert captured["decoded_token"] == {"sub": "user_123"}
assert captured["user_api_key"] == "api_key_123"
assert captured["agent_id"] == "agent_123"
assert captured["token_usage"]["prompt_tokens"] > 0
assert captured["token_usage"]["generated_tokens"] > 0
@pytest.mark.unit
def test_stream_token_usage_counts_tool_call_chunks(monkeypatch):
captured = {}
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
captured["token_usage"] = token_usage.copy()
captured["agent_id"] = agent_id
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
class ToolChunk:
def model_dump(self):
return {
"delta": {
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "get_weather",
"arguments": '{"location":"Seattle"}',
},
}
]
}
}
class DummyLLM:
decoded_token = {"sub": "user_123"}
user_api_key = "api_key_123"
agent_id = "agent_123"
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
@stream_token_usage
def wrapped(self, model, messages, stream, tools, **kwargs):
_ = (model, messages, stream, tools, kwargs)
yield ToolChunk()
yield "done"
messages = [
{
"role": "assistant",
"content": [
{
"function_call": {
"name": "get_weather",
"args": {"location": "Seattle"},
"call_id": "1",
}
}
],
}
]
llm = DummyLLM()
list(wrapped(llm, "gpt-4o", messages, True, None))
assert captured["agent_id"] == "agent_123"
assert captured["token_usage"]["prompt_tokens"] > 0
assert captured["token_usage"]["generated_tokens"] > 0
@pytest.mark.unit
def test_gen_token_usage_counts_tools_and_image_inputs(monkeypatch):
captured = []
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
_ = (decoded_token, user_api_key, agent_id)
captured.append(token_usage.copy())
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
class DummyLLM:
decoded_token = {"sub": "user_123"}
user_api_key = "api_key_123"
agent_id = "agent_123"
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
@gen_token_usage
def wrapped(self, model, messages, stream, tools, **kwargs):
_ = (model, messages, stream, tools, kwargs)
return "ok"
messages = [{"role": "user", "content": "What is in this image?"}]
tools_payload = [
{
"type": "function",
"function": {
"name": "describe_image",
"description": "Describe image content",
"parameters": {
"type": "object",
"properties": {"detail": {"type": "string"}},
},
},
}
]
usage_attachments = [
{
"mime_type": "image/png",
"path": "attachments/example.png",
"data": "abc123",
}
]
llm = DummyLLM()
wrapped(llm, "gpt-4o", messages, False, None)
wrapped(
llm,
"gpt-4o",
messages,
False,
tools_payload,
_usage_attachments=usage_attachments,
)
assert len(captured) == 2
assert captured[1]["prompt_tokens"] > captured[0]["prompt_tokens"]
@pytest.mark.unit
def test_stream_token_usage_counts_tools_and_image_inputs(monkeypatch):
captured = []
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
_ = (decoded_token, user_api_key, agent_id)
captured.append(token_usage.copy())
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
class DummyLLM:
decoded_token = {"sub": "user_123"}
user_api_key = "api_key_123"
agent_id = "agent_123"
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
@stream_token_usage
def wrapped(self, model, messages, stream, tools, **kwargs):
_ = (model, messages, stream, tools, kwargs)
yield "ok"
messages = [{"role": "user", "content": "What is in this image?"}]
tools_payload = [
{
"type": "function",
"function": {
"name": "describe_image",
"description": "Describe image content",
"parameters": {
"type": "object",
"properties": {"detail": {"type": "string"}},
},
},
}
]
usage_attachments = [
{
"mime_type": "image/png",
"path": "attachments/example.png",
"data": "abc123",
}
]
llm = DummyLLM()
list(wrapped(llm, "gpt-4o", messages, True, None))
list(
wrapped(
llm,
"gpt-4o",
messages,
True,
tools_payload,
_usage_attachments=usage_attachments,
)
)
assert len(captured) == 2
assert captured[1]["prompt_tokens"] > captured[0]["prompt_tokens"]
@pytest.mark.unit
def test_update_token_usage_inserts_with_agent_id_only(monkeypatch):
inserted_docs = []
class FakeCollection:
def insert_one(self, doc):
inserted_docs.append(doc)
modules_without_pytest = dict(sys.modules)
modules_without_pytest.pop("pytest", None)
monkeypatch.setattr("application.usage.sys.modules", modules_without_pytest)
monkeypatch.setattr("application.usage.usage_collection", FakeCollection())
update_token_usage(
decoded_token=None,
user_api_key=None,
token_usage={"prompt_tokens": 10, "generated_tokens": 5},
agent_id="agent_123",
)
assert len(inserted_docs) == 1
assert inserted_docs[0]["agent_id"] == "agent_123"
assert inserted_docs[0]["user_id"] is None
assert inserted_docs[0]["api_key"] is None
@pytest.mark.unit
def test_update_token_usage_skips_when_all_ids_missing(monkeypatch):
inserted_docs = []
class FakeCollection:
def insert_one(self, doc):
inserted_docs.append(doc)
modules_without_pytest = dict(sys.modules)
modules_without_pytest.pop("pytest", None)
monkeypatch.setattr("application.usage.sys.modules", modules_without_pytest)
monkeypatch.setattr("application.usage.usage_collection", FakeCollection())
update_token_usage(
decoded_token=None,
user_api_key=None,
token_usage={"prompt_tokens": 10, "generated_tokens": 5},
agent_id=None,
)
assert inserted_docs == []
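
Review note on the last two tests: update_token_usage early-returns whenever "pytest" is in sys.modules, so the tests swap in a copy of sys.modules without pytest to reach the insert/skip logic against a fake collection. The trick in isolation (using pytest's monkeypatch fixture):

import sys

def test_bypasses_pytest_guard(monkeypatch):
    # A sys.modules copy without pytest defeats the `"pytest" in sys.modules`
    # guard; monkeypatch restores the real mapping afterwards.
    modules = dict(sys.modules)
    modules.pop("pytest", None)
    monkeypatch.setattr("application.usage.sys.modules", modules)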