mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-03-02 16:21:50 +00:00
fix: token calc (#2285)
This commit is contained in:
@@ -23,6 +23,7 @@ class BaseAgent(ABC):
|
|||||||
llm_name: str,
|
llm_name: str,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
|
agent_id: Optional[str] = None,
|
||||||
user_api_key: Optional[str] = None,
|
user_api_key: Optional[str] = None,
|
||||||
prompt: str = "",
|
prompt: str = "",
|
||||||
chat_history: Optional[List[Dict]] = None,
|
chat_history: Optional[List[Dict]] = None,
|
||||||
@@ -40,6 +41,7 @@ class BaseAgent(ABC):
|
|||||||
self.llm_name = llm_name
|
self.llm_name = llm_name
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
self.agent_id = agent_id
|
||||||
self.user_api_key = user_api_key
|
self.user_api_key = user_api_key
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
self.decoded_token = decoded_token or {}
|
self.decoded_token = decoded_token or {}
|
||||||
@@ -54,6 +56,7 @@ class BaseAgent(ABC):
|
|||||||
user_api_key=user_api_key,
|
user_api_key=user_api_key,
|
||||||
decoded_token=decoded_token,
|
decoded_token=decoded_token,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
agent_id=agent_id,
|
||||||
)
|
)
|
||||||
self.retrieved_docs = retrieved_docs or []
|
self.retrieved_docs = retrieved_docs or []
|
||||||
self.llm_handler = LLMHandlerCreator.create_handler(
|
self.llm_handler = LLMHandlerCreator.create_handler(
|
||||||
@@ -263,6 +266,11 @@ class BaseAgent(ABC):
|
|||||||
tool_config=tool_config,
|
tool_config=tool_config,
|
||||||
user_id=self.user,
|
user_id=self.user,
|
||||||
)
|
)
|
||||||
|
resolved_arguments = (
|
||||||
|
{"query_params": query_params, "headers": headers, "body": body}
|
||||||
|
if tool_data["name"] == "api_tool"
|
||||||
|
else parameters
|
||||||
|
)
|
||||||
if tool_data["name"] == "api_tool":
|
if tool_data["name"] == "api_tool":
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||||
@@ -292,11 +300,19 @@ class BaseAgent(ABC):
|
|||||||
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
|
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
|
||||||
if artifact_id:
|
if artifact_id:
|
||||||
tool_call_data["artifact_id"] = artifact_id
|
tool_call_data["artifact_id"] = artifact_id
|
||||||
|
result_full = str(result)
|
||||||
|
tool_call_data["resolved_arguments"] = resolved_arguments
|
||||||
|
tool_call_data["result_full"] = result_full
|
||||||
tool_call_data["result"] = (
|
tool_call_data["result"] = (
|
||||||
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
|
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
|
||||||
)
|
)
|
||||||
|
|
||||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
|
stream_tool_call_data = {
|
||||||
|
key: value
|
||||||
|
for key, value in tool_call_data.items()
|
||||||
|
if key not in {"result_full", "resolved_arguments"}
|
||||||
|
}
|
||||||
|
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
|
||||||
self.tool_calls.append(tool_call_data)
|
self.tool_calls.append(tool_call_data)
|
||||||
|
|
||||||
return result, call_id
|
return result, call_id
|
||||||
@@ -304,7 +320,11 @@ class BaseAgent(ABC):
|
|||||||
def _get_truncated_tool_calls(self):
|
def _get_truncated_tool_calls(self):
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
**tool_call,
|
"tool_name": tool_call.get("tool_name"),
|
||||||
|
"call_id": tool_call.get("call_id"),
|
||||||
|
"action_name": tool_call.get("action_name"),
|
||||||
|
"arguments": tool_call.get("arguments"),
|
||||||
|
"artifact_id": tool_call.get("artifact_id"),
|
||||||
"result": (
|
"result": (
|
||||||
f"{str(tool_call['result'])[:50]}..."
|
f"{str(tool_call['result'])[:50]}..."
|
||||||
if len(str(tool_call["result"])) > 50
|
if len(str(tool_call["result"])) > 50
|
||||||
@@ -576,6 +596,9 @@ class BaseAgent(ABC):
|
|||||||
self._validate_context_size(messages)
|
self._validate_context_size(messages)
|
||||||
|
|
||||||
gen_kwargs = {"model": self.model_id, "messages": messages}
|
gen_kwargs = {"model": self.model_id, "messages": messages}
|
||||||
|
if self.attachments:
|
||||||
|
# Usage accounting only; stripped before provider invocation.
|
||||||
|
gen_kwargs["_usage_attachments"] = self.attachments
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(self.llm, "_supports_tools")
|
hasattr(self.llm, "_supports_tools")
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ class AnswerResource(Resource, BaseAnswerResource):
|
|||||||
),
|
),
|
||||||
"retriever": fields.String(required=False, description="Retriever type"),
|
"retriever": fields.String(required=False, description="Retriever type"),
|
||||||
"api_key": fields.String(required=False, description="API key"),
|
"api_key": fields.String(required=False, description="API key"),
|
||||||
|
"agent_id": fields.String(required=False, description="Agent ID"),
|
||||||
"active_docs": fields.String(
|
"active_docs": fields.String(
|
||||||
required=False, description="Active documents"
|
required=False, description="Active documents"
|
||||||
),
|
),
|
||||||
@@ -100,6 +101,9 @@ class AnswerResource(Resource, BaseAnswerResource):
|
|||||||
isNoneDoc=data.get("isNoneDoc"),
|
isNoneDoc=data.get("isNoneDoc"),
|
||||||
index=None,
|
index=None,
|
||||||
should_save_conversation=data.get("save_conversation", True),
|
should_save_conversation=data.get("save_conversation", True),
|
||||||
|
agent_id=processor.agent_id,
|
||||||
|
is_shared_usage=processor.is_shared_usage,
|
||||||
|
shared_token=processor.shared_token,
|
||||||
model_id=processor.model_id,
|
model_id=processor.model_id,
|
||||||
)
|
)
|
||||||
stream_result = self.process_response_stream(stream)
|
stream_result = self.process_response_stream(stream)
|
||||||
|
|||||||
@@ -46,6 +46,27 @@ class BaseAnswerResource:
|
|||||||
return missing_fields
|
return missing_fields
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _prepare_tool_calls_for_logging(
|
||||||
|
tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
if not tool_calls:
|
||||||
|
return []
|
||||||
|
|
||||||
|
prepared = []
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
if not isinstance(tool_call, dict):
|
||||||
|
prepared.append({"result": str(tool_call)[:max_chars]})
|
||||||
|
continue
|
||||||
|
|
||||||
|
item = dict(tool_call)
|
||||||
|
for key in ("result", "result_full"):
|
||||||
|
value = item.get(key)
|
||||||
|
if isinstance(value, str) and len(value) > max_chars:
|
||||||
|
item[key] = value[:max_chars]
|
||||||
|
prepared.append(item)
|
||||||
|
return prepared
|
||||||
|
|
||||||
def check_usage(self, agent_config: Dict) -> Optional[Response]:
|
def check_usage(self, agent_config: Dict) -> Optional[Response]:
|
||||||
"""Check if there is a usage limit and if it is exceeded
|
"""Check if there is a usage limit and if it is exceeded
|
||||||
|
|
||||||
@@ -246,6 +267,7 @@ class BaseAnswerResource:
|
|||||||
user_api_key=user_api_key,
|
user_api_key=user_api_key,
|
||||||
decoded_token=decoded_token,
|
decoded_token=decoded_token,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
agent_id=agent_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if should_save_conversation:
|
if should_save_conversation:
|
||||||
@@ -292,14 +314,20 @@ class BaseAnswerResource:
|
|||||||
data = json.dumps(id_data)
|
data = json.dumps(id_data)
|
||||||
yield f"data: {data}\n\n"
|
yield f"data: {data}\n\n"
|
||||||
|
|
||||||
|
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
|
||||||
|
getattr(agent, "tool_calls", tool_calls) or tool_calls
|
||||||
|
)
|
||||||
|
|
||||||
log_data = {
|
log_data = {
|
||||||
"action": "stream_answer",
|
"action": "stream_answer",
|
||||||
"level": "info",
|
"level": "info",
|
||||||
"user": decoded_token.get("sub"),
|
"user": decoded_token.get("sub"),
|
||||||
"api_key": user_api_key,
|
"api_key": user_api_key,
|
||||||
|
"agent_id": agent_id,
|
||||||
"question": question,
|
"question": question,
|
||||||
"response": response_full,
|
"response": response_full,
|
||||||
"sources": source_log_docs,
|
"sources": source_log_docs,
|
||||||
|
"tool_calls": tool_calls_for_logging,
|
||||||
"attachments": attachment_ids,
|
"attachments": attachment_ids,
|
||||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||||
}
|
}
|
||||||
@@ -330,6 +358,7 @@ class BaseAnswerResource:
|
|||||||
api_key=settings.API_KEY,
|
api_key=settings.API_KEY,
|
||||||
user_api_key=user_api_key,
|
user_api_key=user_api_key,
|
||||||
decoded_token=decoded_token,
|
decoded_token=decoded_token,
|
||||||
|
agent_id=agent_id,
|
||||||
)
|
)
|
||||||
self.conversation_service.save_conversation(
|
self.conversation_service.save_conversation(
|
||||||
conversation_id,
|
conversation_id,
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ class StreamResource(Resource, BaseAnswerResource):
|
|||||||
),
|
),
|
||||||
"retriever": fields.String(required=False, description="Retriever type"),
|
"retriever": fields.String(required=False, description="Retriever type"),
|
||||||
"api_key": fields.String(required=False, description="API key"),
|
"api_key": fields.String(required=False, description="API key"),
|
||||||
|
"agent_id": fields.String(required=False, description="Agent ID"),
|
||||||
"active_docs": fields.String(
|
"active_docs": fields.String(
|
||||||
required=False, description="Active documents"
|
required=False, description="Active documents"
|
||||||
),
|
),
|
||||||
@@ -107,7 +108,7 @@ class StreamResource(Resource, BaseAnswerResource):
|
|||||||
index=data.get("index"),
|
index=data.get("index"),
|
||||||
should_save_conversation=data.get("save_conversation", True),
|
should_save_conversation=data.get("save_conversation", True),
|
||||||
attachment_ids=data.get("attachments", []),
|
attachment_ids=data.get("attachments", []),
|
||||||
agent_id=data.get("agent_id"),
|
agent_id=processor.agent_id,
|
||||||
is_shared_usage=processor.is_shared_usage,
|
is_shared_usage=processor.is_shared_usage,
|
||||||
shared_token=processor.shared_token,
|
shared_token=processor.shared_token,
|
||||||
model_id=processor.model_id,
|
model_id=processor.model_id,
|
||||||
|
|||||||
@@ -134,6 +134,7 @@ class CompressionOrchestrator:
|
|||||||
user_api_key=None,
|
user_api_key=None,
|
||||||
decoded_token=decoded_token,
|
decoded_token=decoded_token,
|
||||||
model_id=compression_model,
|
model_id=compression_model,
|
||||||
|
agent_id=conversation.get("agent_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create compression service with DB update capability
|
# Create compression service with DB update capability
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ class StreamProcessor:
|
|||||||
self.retriever_config = {}
|
self.retriever_config = {}
|
||||||
self.is_shared_usage = False
|
self.is_shared_usage = False
|
||||||
self.shared_token = None
|
self.shared_token = None
|
||||||
|
self.agent_id = self.data.get("agent_id")
|
||||||
self.model_id: Optional[str] = None
|
self.model_id: Optional[str] = None
|
||||||
self.conversation_service = ConversationService()
|
self.conversation_service = ConversationService()
|
||||||
self.compression_orchestrator = CompressionOrchestrator(
|
self.compression_orchestrator = CompressionOrchestrator(
|
||||||
@@ -355,10 +356,13 @@ class StreamProcessor:
|
|||||||
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
|
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
|
||||||
agent_id, self.initial_user_id
|
agent_id, self.initial_user_id
|
||||||
)
|
)
|
||||||
|
self.agent_id = str(agent_id) if agent_id else None
|
||||||
|
|
||||||
api_key = self.data.get("api_key")
|
api_key = self.data.get("api_key")
|
||||||
if api_key:
|
if api_key:
|
||||||
data_key = self._get_data_from_api_key(api_key)
|
data_key = self._get_data_from_api_key(api_key)
|
||||||
|
if data_key.get("_id"):
|
||||||
|
self.agent_id = str(data_key.get("_id"))
|
||||||
self.agent_config.update(
|
self.agent_config.update(
|
||||||
{
|
{
|
||||||
"prompt_id": data_key.get("prompt_id", "default"),
|
"prompt_id": data_key.get("prompt_id", "default"),
|
||||||
@@ -387,6 +391,8 @@ class StreamProcessor:
|
|||||||
self.retriever_config["chunks"] = 2
|
self.retriever_config["chunks"] = 2
|
||||||
elif self.agent_key:
|
elif self.agent_key:
|
||||||
data_key = self._get_data_from_api_key(self.agent_key)
|
data_key = self._get_data_from_api_key(self.agent_key)
|
||||||
|
if data_key.get("_id"):
|
||||||
|
self.agent_id = str(data_key.get("_id"))
|
||||||
self.agent_config.update(
|
self.agent_config.update(
|
||||||
{
|
{
|
||||||
"prompt_id": data_key.get("prompt_id", "default"),
|
"prompt_id": data_key.get("prompt_id", "default"),
|
||||||
@@ -459,6 +465,7 @@ class StreamProcessor:
|
|||||||
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
|
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
|
||||||
model_id=self.model_id,
|
model_id=self.model_id,
|
||||||
user_api_key=self.agent_config["user_api_key"],
|
user_api_key=self.agent_config["user_api_key"],
|
||||||
|
agent_id=self.agent_id,
|
||||||
decoded_token=self.decoded_token,
|
decoded_token=self.decoded_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -754,6 +761,7 @@ class StreamProcessor:
|
|||||||
"llm_name": provider or settings.LLM_PROVIDER,
|
"llm_name": provider or settings.LLM_PROVIDER,
|
||||||
"model_id": self.model_id,
|
"model_id": self.model_id,
|
||||||
"api_key": system_api_key,
|
"api_key": system_api_key,
|
||||||
|
"agent_id": self.agent_id,
|
||||||
"user_api_key": self.agent_config["user_api_key"],
|
"user_api_key": self.agent_config["user_api_key"],
|
||||||
"prompt": rendered_prompt,
|
"prompt": rendered_prompt,
|
||||||
"chat_history": self.history,
|
"chat_history": self.history,
|
||||||
|
|||||||
@@ -13,10 +13,12 @@ class BaseLLM(ABC):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
decoded_token=None,
|
decoded_token=None,
|
||||||
|
agent_id=None,
|
||||||
model_id=None,
|
model_id=None,
|
||||||
base_url=None,
|
base_url=None,
|
||||||
):
|
):
|
||||||
self.decoded_token = decoded_token
|
self.decoded_token = decoded_token
|
||||||
|
self.agent_id = str(agent_id) if agent_id else None
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||||
@@ -33,9 +35,10 @@ class BaseLLM(ABC):
|
|||||||
self._fallback_llm = LLMCreator.create_llm(
|
self._fallback_llm = LLMCreator.create_llm(
|
||||||
settings.FALLBACK_LLM_PROVIDER,
|
settings.FALLBACK_LLM_PROVIDER,
|
||||||
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
|
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
|
||||||
user_api_key=None,
|
user_api_key=getattr(self, "user_api_key", None),
|
||||||
decoded_token=self.decoded_token,
|
decoded_token=self.decoded_token,
|
||||||
model_id=settings.FALLBACK_LLM_NAME,
|
model_id=settings.FALLBACK_LLM_NAME,
|
||||||
|
agent_id=self.agent_id,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
|
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class GoogleLLM(BaseLLM):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
|
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(decoded_token=decoded_token, *args, **kwargs)
|
||||||
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
|
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
|
||||||
self.user_api_key = user_api_key
|
self.user_api_key = user_api_key
|
||||||
|
|
||||||
|
|||||||
@@ -567,6 +567,7 @@ class LLMHandler(ABC):
|
|||||||
getattr(agent, "user_api_key", None),
|
getattr(agent, "user_api_key", None),
|
||||||
getattr(agent, "decoded_token", None),
|
getattr(agent, "decoded_token", None),
|
||||||
model_id=compression_model,
|
model_id=compression_model,
|
||||||
|
agent_id=getattr(agent, "agent_id", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create service without DB persistence capability
|
# Create service without DB persistence capability
|
||||||
|
|||||||
@@ -31,7 +31,15 @@ class LLMCreator:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_llm(
|
def create_llm(
|
||||||
cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
|
cls,
|
||||||
|
type,
|
||||||
|
api_key,
|
||||||
|
user_api_key,
|
||||||
|
decoded_token,
|
||||||
|
model_id=None,
|
||||||
|
agent_id=None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
from application.core.model_utils import get_base_url_for_model
|
from application.core.model_utils import get_base_url_for_model
|
||||||
|
|
||||||
@@ -49,6 +57,7 @@ class LLMCreator:
|
|||||||
user_api_key,
|
user_api_key,
|
||||||
decoded_token=decoded_token,
|
decoded_token=decoded_token,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
agent_id=agent_id,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ class ClassicRAG(BaseRetriever):
|
|||||||
doc_token_limit=50000,
|
doc_token_limit=50000,
|
||||||
model_id="docsgpt-local",
|
model_id="docsgpt-local",
|
||||||
user_api_key=None,
|
user_api_key=None,
|
||||||
|
agent_id=None,
|
||||||
llm_name=settings.LLM_PROVIDER,
|
llm_name=settings.LLM_PROVIDER,
|
||||||
api_key=settings.API_KEY,
|
api_key=settings.API_KEY,
|
||||||
decoded_token=None,
|
decoded_token=None,
|
||||||
@@ -43,6 +44,7 @@ class ClassicRAG(BaseRetriever):
|
|||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
self.doc_token_limit = doc_token_limit
|
self.doc_token_limit = doc_token_limit
|
||||||
self.user_api_key = user_api_key
|
self.user_api_key = user_api_key
|
||||||
|
self.agent_id = agent_id
|
||||||
self.llm_name = llm_name
|
self.llm_name = llm_name
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.llm = LLMCreator.create_llm(
|
self.llm = LLMCreator.create_llm(
|
||||||
@@ -50,6 +52,7 @@ class ClassicRAG(BaseRetriever):
|
|||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
user_api_key=self.user_api_key,
|
user_api_key=self.user_api_key,
|
||||||
decoded_token=decoded_token,
|
decoded_token=decoded_token,
|
||||||
|
agent_id=self.agent_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "active_docs" in source and source["active_docs"] is not None:
|
if "active_docs" in source and source["active_docs"] is not None:
|
||||||
|
|||||||
@@ -1,22 +1,104 @@
|
|||||||
import sys
|
import sys
|
||||||
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.utils import num_tokens_from_object_or_list, num_tokens_from_string
|
from application.utils import num_tokens_from_object_or_list, num_tokens_from_string
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
mongo = MongoDB.get_client()
|
mongo = MongoDB.get_client()
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
usage_collection = db["token_usage"]
|
usage_collection = db["token_usage"]
|
||||||
|
|
||||||
|
|
||||||
def update_token_usage(decoded_token, user_api_key, token_usage):
|
def _serialize_for_token_count(value):
|
||||||
|
"""Normalize payloads into token-countable primitives."""
|
||||||
|
if isinstance(value, str):
|
||||||
|
# Avoid counting large binary payloads in data URLs as text tokens.
|
||||||
|
if value.startswith("data:") and ";base64," in value:
|
||||||
|
return ""
|
||||||
|
return value
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [_serialize_for_token_count(item) for item in value]
|
||||||
|
|
||||||
|
if isinstance(value, dict):
|
||||||
|
serialized = {}
|
||||||
|
for key, raw in value.items():
|
||||||
|
key_lower = str(key).lower()
|
||||||
|
|
||||||
|
# Skip raw binary-like fields; keep textual tool-call fields.
|
||||||
|
if key_lower in {"data", "base64", "image_data"} and isinstance(raw, str):
|
||||||
|
continue
|
||||||
|
if key_lower == "url" and isinstance(raw, str) and ";base64," in raw:
|
||||||
|
continue
|
||||||
|
|
||||||
|
serialized[key] = _serialize_for_token_count(raw)
|
||||||
|
return serialized
|
||||||
|
|
||||||
|
if hasattr(value, "model_dump") and callable(getattr(value, "model_dump")):
|
||||||
|
return _serialize_for_token_count(value.model_dump())
|
||||||
|
if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")):
|
||||||
|
return _serialize_for_token_count(value.to_dict())
|
||||||
|
if hasattr(value, "__dict__"):
|
||||||
|
return _serialize_for_token_count(vars(value))
|
||||||
|
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_tokens(value):
|
||||||
|
serialized = _serialize_for_token_count(value)
|
||||||
|
if isinstance(serialized, str):
|
||||||
|
return num_tokens_from_string(serialized)
|
||||||
|
return num_tokens_from_object_or_list(serialized)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_prompt_tokens(messages, tools=None, usage_attachments=None, **kwargs):
|
||||||
|
prompt_tokens = 0
|
||||||
|
|
||||||
|
for message in messages or []:
|
||||||
|
if not isinstance(message, dict):
|
||||||
|
prompt_tokens += _count_tokens(message)
|
||||||
|
continue
|
||||||
|
|
||||||
|
prompt_tokens += _count_tokens(message.get("content"))
|
||||||
|
|
||||||
|
# Include tool-related message fields for providers that use OpenAI-native format.
|
||||||
|
prompt_tokens += _count_tokens(message.get("tool_calls"))
|
||||||
|
prompt_tokens += _count_tokens(message.get("tool_call_id"))
|
||||||
|
prompt_tokens += _count_tokens(message.get("function_call"))
|
||||||
|
prompt_tokens += _count_tokens(message.get("function_response"))
|
||||||
|
|
||||||
|
# Count tool schema payload passed to the model.
|
||||||
|
prompt_tokens += _count_tokens(tools)
|
||||||
|
|
||||||
|
# Count structured-output/schema payloads when provided.
|
||||||
|
prompt_tokens += _count_tokens(kwargs.get("response_format"))
|
||||||
|
prompt_tokens += _count_tokens(kwargs.get("response_schema"))
|
||||||
|
|
||||||
|
# Optional usage-only attachment context (not forwarded to provider).
|
||||||
|
prompt_tokens += _count_tokens(usage_attachments)
|
||||||
|
|
||||||
|
return prompt_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||||
if "pytest" in sys.modules:
|
if "pytest" in sys.modules:
|
||||||
return
|
return
|
||||||
if decoded_token:
|
user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
|
||||||
user_id = decoded_token["sub"]
|
normalized_agent_id = str(agent_id) if agent_id else None
|
||||||
else:
|
|
||||||
user_id = None
|
if not user_id and not user_api_key and not normalized_agent_id:
|
||||||
|
logger.warning(
|
||||||
|
"Skipping token usage insert: missing user_id, api_key, and agent_id"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
usage_data = {
|
usage_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"api_key": user_api_key,
|
"api_key": user_api_key,
|
||||||
@@ -24,24 +106,31 @@ def update_token_usage(decoded_token, user_api_key, token_usage):
|
|||||||
"generated_tokens": token_usage["generated_tokens"],
|
"generated_tokens": token_usage["generated_tokens"],
|
||||||
"timestamp": datetime.now(),
|
"timestamp": datetime.now(),
|
||||||
}
|
}
|
||||||
|
if normalized_agent_id:
|
||||||
|
usage_data["agent_id"] = normalized_agent_id
|
||||||
usage_collection.insert_one(usage_data)
|
usage_collection.insert_one(usage_data)
|
||||||
|
|
||||||
|
|
||||||
def gen_token_usage(func):
|
def gen_token_usage(func):
|
||||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||||
for message in messages:
|
usage_attachments = kwargs.pop("_usage_attachments", None)
|
||||||
if message["content"]:
|
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(
|
call_usage["prompt_tokens"] += _count_prompt_tokens(
|
||||||
message["content"]
|
messages,
|
||||||
)
|
tools=tools,
|
||||||
|
usage_attachments=usage_attachments,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
result = func(self, model, messages, stream, tools, **kwargs)
|
result = func(self, model, messages, stream, tools, **kwargs)
|
||||||
if isinstance(result, str):
|
call_usage["generated_tokens"] += _count_tokens(result)
|
||||||
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
|
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
|
||||||
else:
|
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
|
||||||
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
|
update_token_usage(
|
||||||
result
|
self.decoded_token,
|
||||||
)
|
self.user_api_key,
|
||||||
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
|
call_usage,
|
||||||
|
getattr(self, "agent_id", None),
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
@@ -49,17 +138,28 @@ def gen_token_usage(func):
|
|||||||
|
|
||||||
def stream_token_usage(func):
|
def stream_token_usage(func):
|
||||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||||
for message in messages:
|
usage_attachments = kwargs.pop("_usage_attachments", None)
|
||||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(
|
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||||
message["content"]
|
call_usage["prompt_tokens"] += _count_prompt_tokens(
|
||||||
)
|
messages,
|
||||||
|
tools=tools,
|
||||||
|
usage_attachments=usage_attachments,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
batch = []
|
batch = []
|
||||||
result = func(self, model, messages, stream, tools, **kwargs)
|
result = func(self, model, messages, stream, tools, **kwargs)
|
||||||
for r in result:
|
for r in result:
|
||||||
batch.append(r)
|
batch.append(r)
|
||||||
yield r
|
yield r
|
||||||
for line in batch:
|
for line in batch:
|
||||||
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
|
call_usage["generated_tokens"] += _count_tokens(line)
|
||||||
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
|
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
|
||||||
|
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
|
||||||
|
update_token_usage(
|
||||||
|
self.decoded_token,
|
||||||
|
self.user_api_key,
|
||||||
|
call_usage,
|
||||||
|
getattr(self, "agent_id", None),
|
||||||
|
)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|||||||
@@ -322,6 +322,7 @@ def run_agent_logic(agent_config, input_data):
|
|||||||
chunks = int(agent_config.get("chunks", 2))
|
chunks = int(agent_config.get("chunks", 2))
|
||||||
prompt_id = agent_config.get("prompt_id", "default")
|
prompt_id = agent_config.get("prompt_id", "default")
|
||||||
user_api_key = agent_config["key"]
|
user_api_key = agent_config["key"]
|
||||||
|
agent_id = str(agent_config.get("_id")) if agent_config.get("_id") else None
|
||||||
agent_type = agent_config.get("agent_type", "classic")
|
agent_type = agent_config.get("agent_type", "classic")
|
||||||
decoded_token = {"sub": agent_config.get("user")}
|
decoded_token = {"sub": agent_config.get("user")}
|
||||||
json_schema = agent_config.get("json_schema")
|
json_schema = agent_config.get("json_schema")
|
||||||
@@ -352,6 +353,7 @@ def run_agent_logic(agent_config, input_data):
|
|||||||
doc_token_limit=doc_token_limit,
|
doc_token_limit=doc_token_limit,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
user_api_key=user_api_key,
|
user_api_key=user_api_key,
|
||||||
|
agent_id=agent_id,
|
||||||
decoded_token=decoded_token,
|
decoded_token=decoded_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -370,6 +372,7 @@ def run_agent_logic(agent_config, input_data):
|
|||||||
llm_name=provider or settings.LLM_PROVIDER,
|
llm_name=provider or settings.LLM_PROVIDER,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
api_key=system_api_key,
|
api_key=system_api_key,
|
||||||
|
agent_id=agent_id,
|
||||||
user_api_key=user_api_key,
|
user_api_key=user_api_key,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
chat_history=[],
|
chat_history=[],
|
||||||
|
|||||||
@@ -199,6 +199,7 @@ class TestStreamProcessorAgentConfiguration:
|
|||||||
try:
|
try:
|
||||||
processor._configure_agent()
|
processor._configure_agent()
|
||||||
assert processor.agent_config is not None
|
assert processor.agent_config is not None
|
||||||
|
assert processor.agent_id == str(agent_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
assert "Invalid API Key" in str(e)
|
assert "Invalid API Key" in str(e)
|
||||||
|
|
||||||
@@ -211,6 +212,7 @@ class TestStreamProcessorAgentConfiguration:
|
|||||||
processor._configure_agent()
|
processor._configure_agent()
|
||||||
|
|
||||||
assert isinstance(processor.agent_config, dict)
|
assert isinstance(processor.agent_config, dict)
|
||||||
|
assert processor.agent_id is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
|
|||||||
326
tests/test_usage.py
Normal file
326
tests/test_usage.py
Normal file
@@ -0,0 +1,326 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from application.usage import (
|
||||||
|
_count_tokens,
|
||||||
|
gen_token_usage,
|
||||||
|
stream_token_usage,
|
||||||
|
update_token_usage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_count_tokens_includes_tool_call_payloads():
|
||||||
|
payload = [
|
||||||
|
{
|
||||||
|
"function_call": {
|
||||||
|
"name": "search_docs",
|
||||||
|
"args": {"query": "pricing limits"},
|
||||||
|
"call_id": "call_1",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"function_response": {
|
||||||
|
"name": "search_docs",
|
||||||
|
"response": {"result": "Found 3 docs"},
|
||||||
|
"call_id": "call_1",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
assert _count_tokens(payload) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_gen_token_usage_counts_structured_tool_content(monkeypatch):
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||||
|
captured["decoded_token"] = decoded_token
|
||||||
|
captured["user_api_key"] = user_api_key
|
||||||
|
captured["token_usage"] = token_usage.copy()
|
||||||
|
captured["agent_id"] = agent_id
|
||||||
|
|
||||||
|
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
|
||||||
|
|
||||||
|
class DummyLLM:
|
||||||
|
decoded_token = {"sub": "user_123"}
|
||||||
|
user_api_key = "api_key_123"
|
||||||
|
agent_id = "agent_123"
|
||||||
|
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||||
|
|
||||||
|
@gen_token_usage
|
||||||
|
def wrapped(self, model, messages, stream, tools, **kwargs):
|
||||||
|
_ = (model, messages, stream, tools, kwargs)
|
||||||
|
return {
|
||||||
|
"tool_calls": [
|
||||||
|
{"name": "read_webpage", "arguments": {"url": "https://example.com"}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"function_call": {
|
||||||
|
"name": "search_docs",
|
||||||
|
"args": {"query": "pricing"},
|
||||||
|
"call_id": "1",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"function_response": {
|
||||||
|
"name": "search_docs",
|
||||||
|
"response": {"result": "Found docs"},
|
||||||
|
"call_id": "1",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
llm = DummyLLM()
|
||||||
|
wrapped(llm, "gpt-4o", messages, False, None)
|
||||||
|
|
||||||
|
assert captured["decoded_token"] == {"sub": "user_123"}
|
||||||
|
assert captured["user_api_key"] == "api_key_123"
|
||||||
|
assert captured["agent_id"] == "agent_123"
|
||||||
|
assert captured["token_usage"]["prompt_tokens"] > 0
|
||||||
|
assert captured["token_usage"]["generated_tokens"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_stream_token_usage_counts_tool_call_chunks(monkeypatch):
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||||
|
captured["token_usage"] = token_usage.copy()
|
||||||
|
captured["agent_id"] = agent_id
|
||||||
|
|
||||||
|
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
|
||||||
|
|
||||||
|
class ToolChunk:
|
||||||
|
def model_dump(self):
|
||||||
|
return {
|
||||||
|
"delta": {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_1",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": '{"location":"Seattle"}',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class DummyLLM:
|
||||||
|
decoded_token = {"sub": "user_123"}
|
||||||
|
user_api_key = "api_key_123"
|
||||||
|
agent_id = "agent_123"
|
||||||
|
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||||
|
|
||||||
|
@stream_token_usage
|
||||||
|
def wrapped(self, model, messages, stream, tools, **kwargs):
|
||||||
|
_ = (model, messages, stream, tools, kwargs)
|
||||||
|
yield ToolChunk()
|
||||||
|
yield "done"
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"function_call": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"args": {"location": "Seattle"},
|
||||||
|
"call_id": "1",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
llm = DummyLLM()
|
||||||
|
list(wrapped(llm, "gpt-4o", messages, True, None))
|
||||||
|
|
||||||
|
assert captured["agent_id"] == "agent_123"
|
||||||
|
assert captured["token_usage"]["prompt_tokens"] > 0
|
||||||
|
assert captured["token_usage"]["generated_tokens"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_gen_token_usage_counts_tools_and_image_inputs(monkeypatch):
|
||||||
|
captured = []
|
||||||
|
|
||||||
|
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||||
|
_ = (decoded_token, user_api_key, agent_id)
|
||||||
|
captured.append(token_usage.copy())
|
||||||
|
|
||||||
|
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
|
||||||
|
|
||||||
|
class DummyLLM:
|
||||||
|
decoded_token = {"sub": "user_123"}
|
||||||
|
user_api_key = "api_key_123"
|
||||||
|
agent_id = "agent_123"
|
||||||
|
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||||
|
|
||||||
|
@gen_token_usage
|
||||||
|
def wrapped(self, model, messages, stream, tools, **kwargs):
|
||||||
|
_ = (model, messages, stream, tools, kwargs)
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "What is in this image?"}]
|
||||||
|
tools_payload = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "describe_image",
|
||||||
|
"description": "Describe image content",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"detail": {"type": "string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
usage_attachments = [
|
||||||
|
{
|
||||||
|
"mime_type": "image/png",
|
||||||
|
"path": "attachments/example.png",
|
||||||
|
"data": "abc123",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
llm = DummyLLM()
|
||||||
|
wrapped(llm, "gpt-4o", messages, False, None)
|
||||||
|
wrapped(
|
||||||
|
llm,
|
||||||
|
"gpt-4o",
|
||||||
|
messages,
|
||||||
|
False,
|
||||||
|
tools_payload,
|
||||||
|
_usage_attachments=usage_attachments,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(captured) == 2
|
||||||
|
assert captured[1]["prompt_tokens"] > captured[0]["prompt_tokens"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_stream_token_usage_counts_tools_and_image_inputs(monkeypatch):
|
||||||
|
captured = []
|
||||||
|
|
||||||
|
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||||
|
_ = (decoded_token, user_api_key, agent_id)
|
||||||
|
captured.append(token_usage.copy())
|
||||||
|
|
||||||
|
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
|
||||||
|
|
||||||
|
class DummyLLM:
|
||||||
|
decoded_token = {"sub": "user_123"}
|
||||||
|
user_api_key = "api_key_123"
|
||||||
|
agent_id = "agent_123"
|
||||||
|
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||||
|
|
||||||
|
@stream_token_usage
|
||||||
|
def wrapped(self, model, messages, stream, tools, **kwargs):
|
||||||
|
_ = (model, messages, stream, tools, kwargs)
|
||||||
|
yield "ok"
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "What is in this image?"}]
|
||||||
|
tools_payload = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "describe_image",
|
||||||
|
"description": "Describe image content",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"detail": {"type": "string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
usage_attachments = [
|
||||||
|
{
|
||||||
|
"mime_type": "image/png",
|
||||||
|
"path": "attachments/example.png",
|
||||||
|
"data": "abc123",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
llm = DummyLLM()
|
||||||
|
list(wrapped(llm, "gpt-4o", messages, True, None))
|
||||||
|
list(
|
||||||
|
wrapped(
|
||||||
|
llm,
|
||||||
|
"gpt-4o",
|
||||||
|
messages,
|
||||||
|
True,
|
||||||
|
tools_payload,
|
||||||
|
_usage_attachments=usage_attachments,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(captured) == 2
|
||||||
|
assert captured[1]["prompt_tokens"] > captured[0]["prompt_tokens"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_update_token_usage_inserts_with_agent_id_only(monkeypatch):
|
||||||
|
inserted_docs = []
|
||||||
|
|
||||||
|
class FakeCollection:
|
||||||
|
def insert_one(self, doc):
|
||||||
|
inserted_docs.append(doc)
|
||||||
|
|
||||||
|
modules_without_pytest = dict(sys.modules)
|
||||||
|
modules_without_pytest.pop("pytest", None)
|
||||||
|
|
||||||
|
monkeypatch.setattr("application.usage.sys.modules", modules_without_pytest)
|
||||||
|
monkeypatch.setattr("application.usage.usage_collection", FakeCollection())
|
||||||
|
|
||||||
|
update_token_usage(
|
||||||
|
decoded_token=None,
|
||||||
|
user_api_key=None,
|
||||||
|
token_usage={"prompt_tokens": 10, "generated_tokens": 5},
|
||||||
|
agent_id="agent_123",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(inserted_docs) == 1
|
||||||
|
assert inserted_docs[0]["agent_id"] == "agent_123"
|
||||||
|
assert inserted_docs[0]["user_id"] is None
|
||||||
|
assert inserted_docs[0]["api_key"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_update_token_usage_skips_when_all_ids_missing(monkeypatch):
|
||||||
|
inserted_docs = []
|
||||||
|
|
||||||
|
class FakeCollection:
|
||||||
|
def insert_one(self, doc):
|
||||||
|
inserted_docs.append(doc)
|
||||||
|
|
||||||
|
modules_without_pytest = dict(sys.modules)
|
||||||
|
modules_without_pytest.pop("pytest", None)
|
||||||
|
|
||||||
|
monkeypatch.setattr("application.usage.sys.modules", modules_without_pytest)
|
||||||
|
monkeypatch.setattr("application.usage.usage_collection", FakeCollection())
|
||||||
|
|
||||||
|
update_token_usage(
|
||||||
|
decoded_token=None,
|
||||||
|
user_api_key=None,
|
||||||
|
token_usage={"prompt_tokens": 10, "generated_tokens": 5},
|
||||||
|
agent_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert inserted_docs == []
|
||||||
Reference in New Issue
Block a user