From 1a2104f474029f1fd4af769be130aba22702e2f7 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 20 Feb 2026 17:37:47 +0000 Subject: [PATCH] fix: token calc (#2285) --- application/agents/base.py | 29 +- application/api/answer/routes/answer.py | 4 + application/api/answer/routes/base.py | 29 ++ application/api/answer/routes/stream.py | 3 +- .../services/compression/orchestrator.py | 1 + .../api/answer/services/stream_processor.py | 8 + application/llm/base.py | 5 +- application/llm/google_ai.py | 2 +- application/llm/handlers/base.py | 1 + application/llm/llm_creator.py | 11 +- application/retriever/classic_rag.py | 3 + application/usage.py | 146 ++++++-- application/worker.py | 3 + .../answer/services/test_stream_processor.py | 2 + tests/test_usage.py | 326 ++++++++++++++++++ 15 files changed, 543 insertions(+), 30 deletions(-) create mode 100644 tests/test_usage.py diff --git a/application/agents/base.py b/application/agents/base.py index bdf7b100..49c84d33 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -23,6 +23,7 @@ class BaseAgent(ABC): llm_name: str, model_id: str, api_key: str, + agent_id: Optional[str] = None, user_api_key: Optional[str] = None, prompt: str = "", chat_history: Optional[List[Dict]] = None, @@ -40,6 +41,7 @@ class BaseAgent(ABC): self.llm_name = llm_name self.model_id = model_id self.api_key = api_key + self.agent_id = agent_id self.user_api_key = user_api_key self.prompt = prompt self.decoded_token = decoded_token or {} @@ -54,6 +56,7 @@ class BaseAgent(ABC): user_api_key=user_api_key, decoded_token=decoded_token, model_id=model_id, + agent_id=agent_id, ) self.retrieved_docs = retrieved_docs or [] self.llm_handler = LLMHandlerCreator.create_handler( @@ -263,6 +266,11 @@ class BaseAgent(ABC): tool_config=tool_config, user_id=self.user, ) + resolved_arguments = ( + {"query_params": query_params, "headers": headers, "body": body} + if tool_data["name"] == "api_tool" + else parameters + ) if tool_data["name"] == "api_tool": logger.debug( f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}" @@ -292,11 +300,19 @@ class BaseAgent(ABC): artifact_id = str(artifact_id).strip() if artifact_id is not None else "" if artifact_id: tool_call_data["artifact_id"] = artifact_id + result_full = str(result) + tool_call_data["resolved_arguments"] = resolved_arguments + tool_call_data["result_full"] = result_full tool_call_data["result"] = ( - f"{str(result)[:50]}..." if len(str(result)) > 50 else result + f"{result_full[:50]}..." if len(result_full) > 50 else result_full ) - yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}} + stream_tool_call_data = { + key: value + for key, value in tool_call_data.items() + if key not in {"result_full", "resolved_arguments"} + } + yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}} self.tool_calls.append(tool_call_data) return result, call_id @@ -304,7 +320,11 @@ class BaseAgent(ABC): def _get_truncated_tool_calls(self): return [ { - **tool_call, + "tool_name": tool_call.get("tool_name"), + "call_id": tool_call.get("call_id"), + "action_name": tool_call.get("action_name"), + "arguments": tool_call.get("arguments"), + "artifact_id": tool_call.get("artifact_id"), "result": ( f"{str(tool_call['result'])[:50]}..." if len(str(tool_call["result"])) > 50 @@ -576,6 +596,9 @@ class BaseAgent(ABC): self._validate_context_size(messages) gen_kwargs = {"model": self.model_id, "messages": messages} + if self.attachments: + # Usage accounting only; stripped before provider invocation. + gen_kwargs["_usage_attachments"] = self.attachments if ( hasattr(self.llm, "_supports_tools") diff --git a/application/api/answer/routes/answer.py b/application/api/answer/routes/answer.py index c24ebffc..b90ffa15 100644 --- a/application/api/answer/routes/answer.py +++ b/application/api/answer/routes/answer.py @@ -42,6 +42,7 @@ class AnswerResource(Resource, BaseAnswerResource): ), "retriever": fields.String(required=False, description="Retriever type"), "api_key": fields.String(required=False, description="API key"), + "agent_id": fields.String(required=False, description="Agent ID"), "active_docs": fields.String( required=False, description="Active documents" ), @@ -100,6 +101,9 @@ class AnswerResource(Resource, BaseAnswerResource): isNoneDoc=data.get("isNoneDoc"), index=None, should_save_conversation=data.get("save_conversation", True), + agent_id=processor.agent_id, + is_shared_usage=processor.is_shared_usage, + shared_token=processor.shared_token, model_id=processor.model_id, ) stream_result = self.process_response_stream(stream) diff --git a/application/api/answer/routes/base.py b/application/api/answer/routes/base.py index be112729..272033ae 100644 --- a/application/api/answer/routes/base.py +++ b/application/api/answer/routes/base.py @@ -46,6 +46,27 @@ class BaseAnswerResource: return missing_fields return None + @staticmethod + def _prepare_tool_calls_for_logging( + tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000 + ) -> List[Dict[str, Any]]: + if not tool_calls: + return [] + + prepared = [] + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + prepared.append({"result": str(tool_call)[:max_chars]}) + continue + + item = dict(tool_call) + for key in ("result", "result_full"): + value = item.get(key) + if isinstance(value, str) and len(value) > max_chars: + item[key] = value[:max_chars] + prepared.append(item) + return prepared + def check_usage(self, agent_config: Dict) -> Optional[Response]: """Check if there is a usage limit and if it is exceeded @@ -246,6 +267,7 @@ class BaseAnswerResource: user_api_key=user_api_key, decoded_token=decoded_token, model_id=model_id, + agent_id=agent_id, ) if should_save_conversation: @@ -292,14 +314,20 @@ class BaseAnswerResource: data = json.dumps(id_data) yield f"data: {data}\n\n" + tool_calls_for_logging = self._prepare_tool_calls_for_logging( + getattr(agent, "tool_calls", tool_calls) or tool_calls + ) + log_data = { "action": "stream_answer", "level": "info", "user": decoded_token.get("sub"), "api_key": user_api_key, + "agent_id": agent_id, "question": question, "response": response_full, "sources": source_log_docs, + "tool_calls": tool_calls_for_logging, "attachments": attachment_ids, "timestamp": datetime.datetime.now(datetime.timezone.utc), } @@ -330,6 +358,7 @@ class BaseAnswerResource: api_key=settings.API_KEY, user_api_key=user_api_key, decoded_token=decoded_token, + agent_id=agent_id, ) self.conversation_service.save_conversation( conversation_id, diff --git a/application/api/answer/routes/stream.py b/application/api/answer/routes/stream.py index 1c3f1778..d1a71b25 100644 --- a/application/api/answer/routes/stream.py +++ b/application/api/answer/routes/stream.py @@ -42,6 +42,7 @@ class StreamResource(Resource, BaseAnswerResource): ), "retriever": fields.String(required=False, description="Retriever type"), "api_key": fields.String(required=False, description="API key"), + "agent_id": fields.String(required=False, description="Agent ID"), "active_docs": fields.String( required=False, description="Active documents" ), @@ -107,7 +108,7 @@ class StreamResource(Resource, BaseAnswerResource): index=data.get("index"), should_save_conversation=data.get("save_conversation", True), attachment_ids=data.get("attachments", []), - agent_id=data.get("agent_id"), + agent_id=processor.agent_id, is_shared_usage=processor.is_shared_usage, shared_token=processor.shared_token, model_id=processor.model_id, diff --git a/application/api/answer/services/compression/orchestrator.py b/application/api/answer/services/compression/orchestrator.py index 797a66d4..11a9032c 100644 --- a/application/api/answer/services/compression/orchestrator.py +++ b/application/api/answer/services/compression/orchestrator.py @@ -134,6 +134,7 @@ class CompressionOrchestrator: user_api_key=None, decoded_token=decoded_token, model_id=compression_model, + agent_id=conversation.get("agent_id"), ) # Create compression service with DB update capability diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py index 6f7ec9d3..588bb642 100644 --- a/application/api/answer/services/stream_processor.py +++ b/application/api/answer/services/stream_processor.py @@ -90,6 +90,7 @@ class StreamProcessor: self.retriever_config = {} self.is_shared_usage = False self.shared_token = None + self.agent_id = self.data.get("agent_id") self.model_id: Optional[str] = None self.conversation_service = ConversationService() self.compression_orchestrator = CompressionOrchestrator( @@ -355,10 +356,13 @@ class StreamProcessor: self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key( agent_id, self.initial_user_id ) + self.agent_id = str(agent_id) if agent_id else None api_key = self.data.get("api_key") if api_key: data_key = self._get_data_from_api_key(api_key) + if data_key.get("_id"): + self.agent_id = str(data_key.get("_id")) self.agent_config.update( { "prompt_id": data_key.get("prompt_id", "default"), @@ -387,6 +391,8 @@ class StreamProcessor: self.retriever_config["chunks"] = 2 elif self.agent_key: data_key = self._get_data_from_api_key(self.agent_key) + if data_key.get("_id"): + self.agent_id = str(data_key.get("_id")) self.agent_config.update( { "prompt_id": data_key.get("prompt_id", "default"), @@ -459,6 +465,7 @@ class StreamProcessor: doc_token_limit=self.retriever_config.get("doc_token_limit", 50000), model_id=self.model_id, user_api_key=self.agent_config["user_api_key"], + agent_id=self.agent_id, decoded_token=self.decoded_token, ) @@ -754,6 +761,7 @@ class StreamProcessor: "llm_name": provider or settings.LLM_PROVIDER, "model_id": self.model_id, "api_key": system_api_key, + "agent_id": self.agent_id, "user_api_key": self.agent_config["user_api_key"], "prompt": rendered_prompt, "chat_history": self.history, diff --git a/application/llm/base.py b/application/llm/base.py index e8ee78eb..9a0f249e 100644 --- a/application/llm/base.py +++ b/application/llm/base.py @@ -13,10 +13,12 @@ class BaseLLM(ABC): def __init__( self, decoded_token=None, + agent_id=None, model_id=None, base_url=None, ): self.decoded_token = decoded_token + self.agent_id = str(agent_id) if agent_id else None self.model_id = model_id self.base_url = base_url self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0} @@ -33,9 +35,10 @@ class BaseLLM(ABC): self._fallback_llm = LLMCreator.create_llm( settings.FALLBACK_LLM_PROVIDER, api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY, - user_api_key=None, + user_api_key=getattr(self, "user_api_key", None), decoded_token=self.decoded_token, model_id=settings.FALLBACK_LLM_NAME, + agent_id=self.agent_id, ) logger.info( f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}" diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py index bd262a87..36892309 100644 --- a/application/llm/google_ai.py +++ b/application/llm/google_ai.py @@ -13,7 +13,7 @@ class GoogleLLM(BaseLLM): def __init__( self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs ): - super().__init__(*args, **kwargs) + super().__init__(decoded_token=decoded_token, *args, **kwargs) self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY self.user_api_key = user_api_key diff --git a/application/llm/handlers/base.py b/application/llm/handlers/base.py index e33bd18e..7537d9c5 100644 --- a/application/llm/handlers/base.py +++ b/application/llm/handlers/base.py @@ -567,6 +567,7 @@ class LLMHandler(ABC): getattr(agent, "user_api_key", None), getattr(agent, "decoded_token", None), model_id=compression_model, + agent_id=getattr(agent, "agent_id", None), ) # Create service without DB persistence capability diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py index 96653831..9f444359 100644 --- a/application/llm/llm_creator.py +++ b/application/llm/llm_creator.py @@ -31,7 +31,15 @@ class LLMCreator: @classmethod def create_llm( - cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs + cls, + type, + api_key, + user_api_key, + decoded_token, + model_id=None, + agent_id=None, + *args, + **kwargs, ): from application.core.model_utils import get_base_url_for_model @@ -49,6 +57,7 @@ class LLMCreator: user_api_key, decoded_token=decoded_token, model_id=model_id, + agent_id=agent_id, base_url=base_url, *args, **kwargs, diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index ec3033bb..4bf8b731 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -18,6 +18,7 @@ class ClassicRAG(BaseRetriever): doc_token_limit=50000, model_id="docsgpt-local", user_api_key=None, + agent_id=None, llm_name=settings.LLM_PROVIDER, api_key=settings.API_KEY, decoded_token=None, @@ -43,6 +44,7 @@ class ClassicRAG(BaseRetriever): self.model_id = model_id self.doc_token_limit = doc_token_limit self.user_api_key = user_api_key + self.agent_id = agent_id self.llm_name = llm_name self.api_key = api_key self.llm = LLMCreator.create_llm( @@ -50,6 +52,7 @@ class ClassicRAG(BaseRetriever): api_key=self.api_key, user_api_key=self.user_api_key, decoded_token=decoded_token, + agent_id=self.agent_id, ) if "active_docs" in source and source["active_docs"] is not None: diff --git a/application/usage.py b/application/usage.py index 46620fff..e41a3588 100644 --- a/application/usage.py +++ b/application/usage.py @@ -1,22 +1,104 @@ import sys +import logging from datetime import datetime from application.core.mongo_db import MongoDB from application.core.settings import settings from application.utils import num_tokens_from_object_or_list, num_tokens_from_string +logger = logging.getLogger(__name__) + mongo = MongoDB.get_client() db = mongo[settings.MONGO_DB_NAME] usage_collection = db["token_usage"] -def update_token_usage(decoded_token, user_api_key, token_usage): +def _serialize_for_token_count(value): + """Normalize payloads into token-countable primitives.""" + if isinstance(value, str): + # Avoid counting large binary payloads in data URLs as text tokens. + if value.startswith("data:") and ";base64," in value: + return "" + return value + + if value is None: + return "" + + if isinstance(value, list): + return [_serialize_for_token_count(item) for item in value] + + if isinstance(value, dict): + serialized = {} + for key, raw in value.items(): + key_lower = str(key).lower() + + # Skip raw binary-like fields; keep textual tool-call fields. + if key_lower in {"data", "base64", "image_data"} and isinstance(raw, str): + continue + if key_lower == "url" and isinstance(raw, str) and ";base64," in raw: + continue + + serialized[key] = _serialize_for_token_count(raw) + return serialized + + if hasattr(value, "model_dump") and callable(getattr(value, "model_dump")): + return _serialize_for_token_count(value.model_dump()) + if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")): + return _serialize_for_token_count(value.to_dict()) + if hasattr(value, "__dict__"): + return _serialize_for_token_count(vars(value)) + + return str(value) + + +def _count_tokens(value): + serialized = _serialize_for_token_count(value) + if isinstance(serialized, str): + return num_tokens_from_string(serialized) + return num_tokens_from_object_or_list(serialized) + + +def _count_prompt_tokens(messages, tools=None, usage_attachments=None, **kwargs): + prompt_tokens = 0 + + for message in messages or []: + if not isinstance(message, dict): + prompt_tokens += _count_tokens(message) + continue + + prompt_tokens += _count_tokens(message.get("content")) + + # Include tool-related message fields for providers that use OpenAI-native format. + prompt_tokens += _count_tokens(message.get("tool_calls")) + prompt_tokens += _count_tokens(message.get("tool_call_id")) + prompt_tokens += _count_tokens(message.get("function_call")) + prompt_tokens += _count_tokens(message.get("function_response")) + + # Count tool schema payload passed to the model. + prompt_tokens += _count_tokens(tools) + + # Count structured-output/schema payloads when provided. + prompt_tokens += _count_tokens(kwargs.get("response_format")) + prompt_tokens += _count_tokens(kwargs.get("response_schema")) + + # Optional usage-only attachment context (not forwarded to provider). + prompt_tokens += _count_tokens(usage_attachments) + + return prompt_tokens + + +def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None): if "pytest" in sys.modules: return - if decoded_token: - user_id = decoded_token["sub"] - else: - user_id = None + user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None + normalized_agent_id = str(agent_id) if agent_id else None + + if not user_id and not user_api_key and not normalized_agent_id: + logger.warning( + "Skipping token usage insert: missing user_id, api_key, and agent_id" + ) + return + usage_data = { "user_id": user_id, "api_key": user_api_key, @@ -24,24 +106,31 @@ def update_token_usage(decoded_token, user_api_key, token_usage): "generated_tokens": token_usage["generated_tokens"], "timestamp": datetime.now(), } + if normalized_agent_id: + usage_data["agent_id"] = normalized_agent_id usage_collection.insert_one(usage_data) def gen_token_usage(func): def wrapper(self, model, messages, stream, tools, **kwargs): - for message in messages: - if message["content"]: - self.token_usage["prompt_tokens"] += num_tokens_from_string( - message["content"] - ) + usage_attachments = kwargs.pop("_usage_attachments", None) + call_usage = {"prompt_tokens": 0, "generated_tokens": 0} + call_usage["prompt_tokens"] += _count_prompt_tokens( + messages, + tools=tools, + usage_attachments=usage_attachments, + **kwargs, + ) result = func(self, model, messages, stream, tools, **kwargs) - if isinstance(result, str): - self.token_usage["generated_tokens"] += num_tokens_from_string(result) - else: - self.token_usage["generated_tokens"] += num_tokens_from_object_or_list( - result - ) - update_token_usage(self.decoded_token, self.user_api_key, self.token_usage) + call_usage["generated_tokens"] += _count_tokens(result) + self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"] + self.token_usage["generated_tokens"] += call_usage["generated_tokens"] + update_token_usage( + self.decoded_token, + self.user_api_key, + call_usage, + getattr(self, "agent_id", None), + ) return result return wrapper @@ -49,17 +138,28 @@ def gen_token_usage(func): def stream_token_usage(func): def wrapper(self, model, messages, stream, tools, **kwargs): - for message in messages: - self.token_usage["prompt_tokens"] += num_tokens_from_string( - message["content"] - ) + usage_attachments = kwargs.pop("_usage_attachments", None) + call_usage = {"prompt_tokens": 0, "generated_tokens": 0} + call_usage["prompt_tokens"] += _count_prompt_tokens( + messages, + tools=tools, + usage_attachments=usage_attachments, + **kwargs, + ) batch = [] result = func(self, model, messages, stream, tools, **kwargs) for r in result: batch.append(r) yield r for line in batch: - self.token_usage["generated_tokens"] += num_tokens_from_string(line) - update_token_usage(self.decoded_token, self.user_api_key, self.token_usage) + call_usage["generated_tokens"] += _count_tokens(line) + self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"] + self.token_usage["generated_tokens"] += call_usage["generated_tokens"] + update_token_usage( + self.decoded_token, + self.user_api_key, + call_usage, + getattr(self, "agent_id", None), + ) return wrapper diff --git a/application/worker.py b/application/worker.py index d746396c..e1f6a733 100755 --- a/application/worker.py +++ b/application/worker.py @@ -322,6 +322,7 @@ def run_agent_logic(agent_config, input_data): chunks = int(agent_config.get("chunks", 2)) prompt_id = agent_config.get("prompt_id", "default") user_api_key = agent_config["key"] + agent_id = str(agent_config.get("_id")) if agent_config.get("_id") else None agent_type = agent_config.get("agent_type", "classic") decoded_token = {"sub": agent_config.get("user")} json_schema = agent_config.get("json_schema") @@ -352,6 +353,7 @@ def run_agent_logic(agent_config, input_data): doc_token_limit=doc_token_limit, model_id=model_id, user_api_key=user_api_key, + agent_id=agent_id, decoded_token=decoded_token, ) @@ -370,6 +372,7 @@ def run_agent_logic(agent_config, input_data): llm_name=provider or settings.LLM_PROVIDER, model_id=model_id, api_key=system_api_key, + agent_id=agent_id, user_api_key=user_api_key, prompt=prompt, chat_history=[], diff --git a/tests/api/answer/services/test_stream_processor.py b/tests/api/answer/services/test_stream_processor.py index 727b3a73..45cbbb67 100644 --- a/tests/api/answer/services/test_stream_processor.py +++ b/tests/api/answer/services/test_stream_processor.py @@ -199,6 +199,7 @@ class TestStreamProcessorAgentConfiguration: try: processor._configure_agent() assert processor.agent_config is not None + assert processor.agent_id == str(agent_id) except Exception as e: assert "Invalid API Key" in str(e) @@ -211,6 +212,7 @@ class TestStreamProcessorAgentConfiguration: processor._configure_agent() assert isinstance(processor.agent_config, dict) + assert processor.agent_id is None @pytest.mark.unit diff --git a/tests/test_usage.py b/tests/test_usage.py new file mode 100644 index 00000000..10185fef --- /dev/null +++ b/tests/test_usage.py @@ -0,0 +1,326 @@ +import sys + +import pytest + +from application.usage import ( + _count_tokens, + gen_token_usage, + stream_token_usage, + update_token_usage, +) + + +@pytest.mark.unit +def test_count_tokens_includes_tool_call_payloads(): + payload = [ + { + "function_call": { + "name": "search_docs", + "args": {"query": "pricing limits"}, + "call_id": "call_1", + } + }, + { + "function_response": { + "name": "search_docs", + "response": {"result": "Found 3 docs"}, + "call_id": "call_1", + } + }, + ] + + assert _count_tokens(payload) > 0 + + +@pytest.mark.unit +def test_gen_token_usage_counts_structured_tool_content(monkeypatch): + captured = {} + + def fake_update(decoded_token, user_api_key, token_usage, agent_id=None): + captured["decoded_token"] = decoded_token + captured["user_api_key"] = user_api_key + captured["token_usage"] = token_usage.copy() + captured["agent_id"] = agent_id + + monkeypatch.setattr("application.usage.update_token_usage", fake_update) + + class DummyLLM: + decoded_token = {"sub": "user_123"} + user_api_key = "api_key_123" + agent_id = "agent_123" + token_usage = {"prompt_tokens": 0, "generated_tokens": 0} + + @gen_token_usage + def wrapped(self, model, messages, stream, tools, **kwargs): + _ = (model, messages, stream, tools, kwargs) + return { + "tool_calls": [ + {"name": "read_webpage", "arguments": {"url": "https://example.com"}} + ] + } + + messages = [ + { + "role": "assistant", + "content": [ + { + "function_call": { + "name": "search_docs", + "args": {"query": "pricing"}, + "call_id": "1", + } + } + ], + }, + { + "role": "tool", + "content": [ + { + "function_response": { + "name": "search_docs", + "response": {"result": "Found docs"}, + "call_id": "1", + } + } + ], + }, + ] + + llm = DummyLLM() + wrapped(llm, "gpt-4o", messages, False, None) + + assert captured["decoded_token"] == {"sub": "user_123"} + assert captured["user_api_key"] == "api_key_123" + assert captured["agent_id"] == "agent_123" + assert captured["token_usage"]["prompt_tokens"] > 0 + assert captured["token_usage"]["generated_tokens"] > 0 + + +@pytest.mark.unit +def test_stream_token_usage_counts_tool_call_chunks(monkeypatch): + captured = {} + + def fake_update(decoded_token, user_api_key, token_usage, agent_id=None): + captured["token_usage"] = token_usage.copy() + captured["agent_id"] = agent_id + + monkeypatch.setattr("application.usage.update_token_usage", fake_update) + + class ToolChunk: + def model_dump(self): + return { + "delta": { + "tool_calls": [ + { + "id": "call_1", + "function": { + "name": "get_weather", + "arguments": '{"location":"Seattle"}', + }, + } + ] + } + } + + class DummyLLM: + decoded_token = {"sub": "user_123"} + user_api_key = "api_key_123" + agent_id = "agent_123" + token_usage = {"prompt_tokens": 0, "generated_tokens": 0} + + @stream_token_usage + def wrapped(self, model, messages, stream, tools, **kwargs): + _ = (model, messages, stream, tools, kwargs) + yield ToolChunk() + yield "done" + + messages = [ + { + "role": "assistant", + "content": [ + { + "function_call": { + "name": "get_weather", + "args": {"location": "Seattle"}, + "call_id": "1", + } + } + ], + } + ] + + llm = DummyLLM() + list(wrapped(llm, "gpt-4o", messages, True, None)) + + assert captured["agent_id"] == "agent_123" + assert captured["token_usage"]["prompt_tokens"] > 0 + assert captured["token_usage"]["generated_tokens"] > 0 + + +@pytest.mark.unit +def test_gen_token_usage_counts_tools_and_image_inputs(monkeypatch): + captured = [] + + def fake_update(decoded_token, user_api_key, token_usage, agent_id=None): + _ = (decoded_token, user_api_key, agent_id) + captured.append(token_usage.copy()) + + monkeypatch.setattr("application.usage.update_token_usage", fake_update) + + class DummyLLM: + decoded_token = {"sub": "user_123"} + user_api_key = "api_key_123" + agent_id = "agent_123" + token_usage = {"prompt_tokens": 0, "generated_tokens": 0} + + @gen_token_usage + def wrapped(self, model, messages, stream, tools, **kwargs): + _ = (model, messages, stream, tools, kwargs) + return "ok" + + messages = [{"role": "user", "content": "What is in this image?"}] + tools_payload = [ + { + "type": "function", + "function": { + "name": "describe_image", + "description": "Describe image content", + "parameters": { + "type": "object", + "properties": {"detail": {"type": "string"}}, + }, + }, + } + ] + usage_attachments = [ + { + "mime_type": "image/png", + "path": "attachments/example.png", + "data": "abc123", + } + ] + + llm = DummyLLM() + wrapped(llm, "gpt-4o", messages, False, None) + wrapped( + llm, + "gpt-4o", + messages, + False, + tools_payload, + _usage_attachments=usage_attachments, + ) + + assert len(captured) == 2 + assert captured[1]["prompt_tokens"] > captured[0]["prompt_tokens"] + + +@pytest.mark.unit +def test_stream_token_usage_counts_tools_and_image_inputs(monkeypatch): + captured = [] + + def fake_update(decoded_token, user_api_key, token_usage, agent_id=None): + _ = (decoded_token, user_api_key, agent_id) + captured.append(token_usage.copy()) + + monkeypatch.setattr("application.usage.update_token_usage", fake_update) + + class DummyLLM: + decoded_token = {"sub": "user_123"} + user_api_key = "api_key_123" + agent_id = "agent_123" + token_usage = {"prompt_tokens": 0, "generated_tokens": 0} + + @stream_token_usage + def wrapped(self, model, messages, stream, tools, **kwargs): + _ = (model, messages, stream, tools, kwargs) + yield "ok" + + messages = [{"role": "user", "content": "What is in this image?"}] + tools_payload = [ + { + "type": "function", + "function": { + "name": "describe_image", + "description": "Describe image content", + "parameters": { + "type": "object", + "properties": {"detail": {"type": "string"}}, + }, + }, + } + ] + usage_attachments = [ + { + "mime_type": "image/png", + "path": "attachments/example.png", + "data": "abc123", + } + ] + + llm = DummyLLM() + list(wrapped(llm, "gpt-4o", messages, True, None)) + list( + wrapped( + llm, + "gpt-4o", + messages, + True, + tools_payload, + _usage_attachments=usage_attachments, + ) + ) + + assert len(captured) == 2 + assert captured[1]["prompt_tokens"] > captured[0]["prompt_tokens"] + + +@pytest.mark.unit +def test_update_token_usage_inserts_with_agent_id_only(monkeypatch): + inserted_docs = [] + + class FakeCollection: + def insert_one(self, doc): + inserted_docs.append(doc) + + modules_without_pytest = dict(sys.modules) + modules_without_pytest.pop("pytest", None) + + monkeypatch.setattr("application.usage.sys.modules", modules_without_pytest) + monkeypatch.setattr("application.usage.usage_collection", FakeCollection()) + + update_token_usage( + decoded_token=None, + user_api_key=None, + token_usage={"prompt_tokens": 10, "generated_tokens": 5}, + agent_id="agent_123", + ) + + assert len(inserted_docs) == 1 + assert inserted_docs[0]["agent_id"] == "agent_123" + assert inserted_docs[0]["user_id"] is None + assert inserted_docs[0]["api_key"] is None + + +@pytest.mark.unit +def test_update_token_usage_skips_when_all_ids_missing(monkeypatch): + inserted_docs = [] + + class FakeCollection: + def insert_one(self, doc): + inserted_docs.append(doc) + + modules_without_pytest = dict(sys.modules) + modules_without_pytest.pop("pytest", None) + + monkeypatch.setattr("application.usage.sys.modules", modules_without_pytest) + monkeypatch.setattr("application.usage.usage_collection", FakeCollection()) + + update_token_usage( + decoded_token=None, + user_api_key=None, + token_usage={"prompt_tokens": 10, "generated_tokens": 5}, + agent_id=None, + ) + + assert inserted_docs == []