mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-21 20:01:26 +00:00
fix: token calc (#2285)
This commit is contained in:
@@ -23,6 +23,7 @@ class BaseAgent(ABC):
|
||||
llm_name: str,
|
||||
model_id: str,
|
||||
api_key: str,
|
||||
agent_id: Optional[str] = None,
|
||||
user_api_key: Optional[str] = None,
|
||||
prompt: str = "",
|
||||
chat_history: Optional[List[Dict]] = None,
|
||||
@@ -40,6 +41,7 @@ class BaseAgent(ABC):
|
||||
self.llm_name = llm_name
|
||||
self.model_id = model_id
|
||||
self.api_key = api_key
|
||||
self.agent_id = agent_id
|
||||
self.user_api_key = user_api_key
|
||||
self.prompt = prompt
|
||||
self.decoded_token = decoded_token or {}
|
||||
@@ -54,6 +56,7 @@ class BaseAgent(ABC):
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
self.retrieved_docs = retrieved_docs or []
|
||||
self.llm_handler = LLMHandlerCreator.create_handler(
|
||||
@@ -263,6 +266,11 @@ class BaseAgent(ABC):
|
||||
tool_config=tool_config,
|
||||
user_id=self.user,
|
||||
)
|
||||
resolved_arguments = (
|
||||
{"query_params": query_params, "headers": headers, "body": body}
|
||||
if tool_data["name"] == "api_tool"
|
||||
else parameters
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
logger.debug(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
@@ -292,11 +300,19 @@ class BaseAgent(ABC):
|
||||
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
|
||||
if artifact_id:
|
||||
tool_call_data["artifact_id"] = artifact_id
|
||||
result_full = str(result)
|
||||
tool_call_data["resolved_arguments"] = resolved_arguments
|
||||
tool_call_data["result_full"] = result_full
|
||||
tool_call_data["result"] = (
|
||||
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
|
||||
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
|
||||
)
|
||||
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
|
||||
stream_tool_call_data = {
|
||||
key: value
|
||||
for key, value in tool_call_data.items()
|
||||
if key not in {"result_full", "resolved_arguments"}
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
|
||||
return result, call_id
|
||||
@@ -304,7 +320,11 @@ class BaseAgent(ABC):
|
||||
def _get_truncated_tool_calls(self):
|
||||
return [
|
||||
{
|
||||
**tool_call,
|
||||
"tool_name": tool_call.get("tool_name"),
|
||||
"call_id": tool_call.get("call_id"),
|
||||
"action_name": tool_call.get("action_name"),
|
||||
"arguments": tool_call.get("arguments"),
|
||||
"artifact_id": tool_call.get("artifact_id"),
|
||||
"result": (
|
||||
f"{str(tool_call['result'])[:50]}..."
|
||||
if len(str(tool_call["result"])) > 50
|
||||
@@ -576,6 +596,9 @@ class BaseAgent(ABC):
|
||||
self._validate_context_size(messages)
|
||||
|
||||
gen_kwargs = {"model": self.model_id, "messages": messages}
|
||||
if self.attachments:
|
||||
# Usage accounting only; stripped before provider invocation.
|
||||
gen_kwargs["_usage_attachments"] = self.attachments
|
||||
|
||||
if (
|
||||
hasattr(self.llm, "_supports_tools")
|
||||
|
||||
@@ -42,6 +42,7 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"agent_id": fields.String(required=False, description="Agent ID"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
@@ -100,6 +101,9 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
agent_id=processor.agent_id,
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
)
|
||||
stream_result = self.process_response_stream(stream)
|
||||
|
||||
@@ -46,6 +46,27 @@ class BaseAnswerResource:
|
||||
return missing_fields
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _prepare_tool_calls_for_logging(
|
||||
tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000
|
||||
) -> List[Dict[str, Any]]:
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
prepared = []
|
||||
for tool_call in tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
prepared.append({"result": str(tool_call)[:max_chars]})
|
||||
continue
|
||||
|
||||
item = dict(tool_call)
|
||||
for key in ("result", "result_full"):
|
||||
value = item.get(key)
|
||||
if isinstance(value, str) and len(value) > max_chars:
|
||||
item[key] = value[:max_chars]
|
||||
prepared.append(item)
|
||||
return prepared
|
||||
|
||||
def check_usage(self, agent_config: Dict) -> Optional[Response]:
|
||||
"""Check if there is a usage limit and if it is exceeded
|
||||
|
||||
@@ -246,6 +267,7 @@ class BaseAnswerResource:
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
if should_save_conversation:
|
||||
@@ -292,14 +314,20 @@ class BaseAnswerResource:
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
|
||||
getattr(agent, "tool_calls", tool_calls) or tool_calls
|
||||
)
|
||||
|
||||
log_data = {
|
||||
"action": "stream_answer",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls_for_logging,
|
||||
"attachments": attachment_ids,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
@@ -330,6 +358,7 @@ class BaseAnswerResource:
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
|
||||
@@ -42,6 +42,7 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"agent_id": fields.String(required=False, description="Agent ID"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
@@ -107,7 +108,7 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
index=data.get("index"),
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
attachment_ids=data.get("attachments", []),
|
||||
agent_id=data.get("agent_id"),
|
||||
agent_id=processor.agent_id,
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
|
||||
@@ -134,6 +134,7 @@ class CompressionOrchestrator:
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
model_id=compression_model,
|
||||
agent_id=conversation.get("agent_id"),
|
||||
)
|
||||
|
||||
# Create compression service with DB update capability
|
||||
|
||||
@@ -90,6 +90,7 @@ class StreamProcessor:
|
||||
self.retriever_config = {}
|
||||
self.is_shared_usage = False
|
||||
self.shared_token = None
|
||||
self.agent_id = self.data.get("agent_id")
|
||||
self.model_id: Optional[str] = None
|
||||
self.conversation_service = ConversationService()
|
||||
self.compression_orchestrator = CompressionOrchestrator(
|
||||
@@ -355,10 +356,13 @@ class StreamProcessor:
|
||||
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
|
||||
agent_id, self.initial_user_id
|
||||
)
|
||||
self.agent_id = str(agent_id) if agent_id else None
|
||||
|
||||
api_key = self.data.get("api_key")
|
||||
if api_key:
|
||||
data_key = self._get_data_from_api_key(api_key)
|
||||
if data_key.get("_id"):
|
||||
self.agent_id = str(data_key.get("_id"))
|
||||
self.agent_config.update(
|
||||
{
|
||||
"prompt_id": data_key.get("prompt_id", "default"),
|
||||
@@ -387,6 +391,8 @@ class StreamProcessor:
|
||||
self.retriever_config["chunks"] = 2
|
||||
elif self.agent_key:
|
||||
data_key = self._get_data_from_api_key(self.agent_key)
|
||||
if data_key.get("_id"):
|
||||
self.agent_id = str(data_key.get("_id"))
|
||||
self.agent_config.update(
|
||||
{
|
||||
"prompt_id": data_key.get("prompt_id", "default"),
|
||||
@@ -459,6 +465,7 @@ class StreamProcessor:
|
||||
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
|
||||
model_id=self.model_id,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
agent_id=self.agent_id,
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
|
||||
@@ -754,6 +761,7 @@ class StreamProcessor:
|
||||
"llm_name": provider or settings.LLM_PROVIDER,
|
||||
"model_id": self.model_id,
|
||||
"api_key": system_api_key,
|
||||
"agent_id": self.agent_id,
|
||||
"user_api_key": self.agent_config["user_api_key"],
|
||||
"prompt": rendered_prompt,
|
||||
"chat_history": self.history,
|
||||
|
||||
@@ -13,10 +13,12 @@ class BaseLLM(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoded_token=None,
|
||||
agent_id=None,
|
||||
model_id=None,
|
||||
base_url=None,
|
||||
):
|
||||
self.decoded_token = decoded_token
|
||||
self.agent_id = str(agent_id) if agent_id else None
|
||||
self.model_id = model_id
|
||||
self.base_url = base_url
|
||||
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
@@ -33,9 +35,10 @@ class BaseLLM(ABC):
|
||||
self._fallback_llm = LLMCreator.create_llm(
|
||||
settings.FALLBACK_LLM_PROVIDER,
|
||||
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
|
||||
user_api_key=None,
|
||||
user_api_key=getattr(self, "user_api_key", None),
|
||||
decoded_token=self.decoded_token,
|
||||
model_id=settings.FALLBACK_LLM_NAME,
|
||||
agent_id=self.agent_id,
|
||||
)
|
||||
logger.info(
|
||||
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
|
||||
|
||||
@@ -13,7 +13,7 @@ class GoogleLLM(BaseLLM):
|
||||
def __init__(
|
||||
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__(decoded_token=decoded_token, *args, **kwargs)
|
||||
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
|
||||
@@ -567,6 +567,7 @@ class LLMHandler(ABC):
|
||||
getattr(agent, "user_api_key", None),
|
||||
getattr(agent, "decoded_token", None),
|
||||
model_id=compression_model,
|
||||
agent_id=getattr(agent, "agent_id", None),
|
||||
)
|
||||
|
||||
# Create service without DB persistence capability
|
||||
|
||||
@@ -31,7 +31,15 @@ class LLMCreator:
|
||||
|
||||
@classmethod
|
||||
def create_llm(
|
||||
cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
|
||||
cls,
|
||||
type,
|
||||
api_key,
|
||||
user_api_key,
|
||||
decoded_token,
|
||||
model_id=None,
|
||||
agent_id=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
from application.core.model_utils import get_base_url_for_model
|
||||
|
||||
@@ -49,6 +57,7 @@ class LLMCreator:
|
||||
user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
base_url=base_url,
|
||||
*args,
|
||||
**kwargs,
|
||||
|
||||
@@ -18,6 +18,7 @@ class ClassicRAG(BaseRetriever):
|
||||
doc_token_limit=50000,
|
||||
model_id="docsgpt-local",
|
||||
user_api_key=None,
|
||||
agent_id=None,
|
||||
llm_name=settings.LLM_PROVIDER,
|
||||
api_key=settings.API_KEY,
|
||||
decoded_token=None,
|
||||
@@ -43,6 +44,7 @@ class ClassicRAG(BaseRetriever):
|
||||
self.model_id = model_id
|
||||
self.doc_token_limit = doc_token_limit
|
||||
self.user_api_key = user_api_key
|
||||
self.agent_id = agent_id
|
||||
self.llm_name = llm_name
|
||||
self.api_key = api_key
|
||||
self.llm = LLMCreator.create_llm(
|
||||
@@ -50,6 +52,7 @@ class ClassicRAG(BaseRetriever):
|
||||
api_key=self.api_key,
|
||||
user_api_key=self.user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
agent_id=self.agent_id,
|
||||
)
|
||||
|
||||
if "active_docs" in source and source["active_docs"] is not None:
|
||||
|
||||
@@ -1,22 +1,104 @@
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.utils import num_tokens_from_object_or_list, num_tokens_from_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
usage_collection = db["token_usage"]
|
||||
|
||||
|
||||
def update_token_usage(decoded_token, user_api_key, token_usage):
|
||||
def _serialize_for_token_count(value):
|
||||
"""Normalize payloads into token-countable primitives."""
|
||||
if isinstance(value, str):
|
||||
# Avoid counting large binary payloads in data URLs as text tokens.
|
||||
if value.startswith("data:") and ";base64," in value:
|
||||
return ""
|
||||
return value
|
||||
|
||||
if value is None:
|
||||
return ""
|
||||
|
||||
if isinstance(value, list):
|
||||
return [_serialize_for_token_count(item) for item in value]
|
||||
|
||||
if isinstance(value, dict):
|
||||
serialized = {}
|
||||
for key, raw in value.items():
|
||||
key_lower = str(key).lower()
|
||||
|
||||
# Skip raw binary-like fields; keep textual tool-call fields.
|
||||
if key_lower in {"data", "base64", "image_data"} and isinstance(raw, str):
|
||||
continue
|
||||
if key_lower == "url" and isinstance(raw, str) and ";base64," in raw:
|
||||
continue
|
||||
|
||||
serialized[key] = _serialize_for_token_count(raw)
|
||||
return serialized
|
||||
|
||||
if hasattr(value, "model_dump") and callable(getattr(value, "model_dump")):
|
||||
return _serialize_for_token_count(value.model_dump())
|
||||
if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")):
|
||||
return _serialize_for_token_count(value.to_dict())
|
||||
if hasattr(value, "__dict__"):
|
||||
return _serialize_for_token_count(vars(value))
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
def _count_tokens(value):
|
||||
serialized = _serialize_for_token_count(value)
|
||||
if isinstance(serialized, str):
|
||||
return num_tokens_from_string(serialized)
|
||||
return num_tokens_from_object_or_list(serialized)
|
||||
|
||||
|
||||
def _count_prompt_tokens(messages, tools=None, usage_attachments=None, **kwargs):
|
||||
prompt_tokens = 0
|
||||
|
||||
for message in messages or []:
|
||||
if not isinstance(message, dict):
|
||||
prompt_tokens += _count_tokens(message)
|
||||
continue
|
||||
|
||||
prompt_tokens += _count_tokens(message.get("content"))
|
||||
|
||||
# Include tool-related message fields for providers that use OpenAI-native format.
|
||||
prompt_tokens += _count_tokens(message.get("tool_calls"))
|
||||
prompt_tokens += _count_tokens(message.get("tool_call_id"))
|
||||
prompt_tokens += _count_tokens(message.get("function_call"))
|
||||
prompt_tokens += _count_tokens(message.get("function_response"))
|
||||
|
||||
# Count tool schema payload passed to the model.
|
||||
prompt_tokens += _count_tokens(tools)
|
||||
|
||||
# Count structured-output/schema payloads when provided.
|
||||
prompt_tokens += _count_tokens(kwargs.get("response_format"))
|
||||
prompt_tokens += _count_tokens(kwargs.get("response_schema"))
|
||||
|
||||
# Optional usage-only attachment context (not forwarded to provider).
|
||||
prompt_tokens += _count_tokens(usage_attachments)
|
||||
|
||||
return prompt_tokens
|
||||
|
||||
|
||||
def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
if "pytest" in sys.modules:
|
||||
return
|
||||
if decoded_token:
|
||||
user_id = decoded_token["sub"]
|
||||
else:
|
||||
user_id = None
|
||||
user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
|
||||
normalized_agent_id = str(agent_id) if agent_id else None
|
||||
|
||||
if not user_id and not user_api_key and not normalized_agent_id:
|
||||
logger.warning(
|
||||
"Skipping token usage insert: missing user_id, api_key, and agent_id"
|
||||
)
|
||||
return
|
||||
|
||||
usage_data = {
|
||||
"user_id": user_id,
|
||||
"api_key": user_api_key,
|
||||
@@ -24,24 +106,31 @@ def update_token_usage(decoded_token, user_api_key, token_usage):
|
||||
"generated_tokens": token_usage["generated_tokens"],
|
||||
"timestamp": datetime.now(),
|
||||
}
|
||||
if normalized_agent_id:
|
||||
usage_data["agent_id"] = normalized_agent_id
|
||||
usage_collection.insert_one(usage_data)
|
||||
|
||||
|
||||
def gen_token_usage(func):
|
||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||
for message in messages:
|
||||
if message["content"]:
|
||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(
|
||||
message["content"]
|
||||
)
|
||||
usage_attachments = kwargs.pop("_usage_attachments", None)
|
||||
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
call_usage["prompt_tokens"] += _count_prompt_tokens(
|
||||
messages,
|
||||
tools=tools,
|
||||
usage_attachments=usage_attachments,
|
||||
**kwargs,
|
||||
)
|
||||
result = func(self, model, messages, stream, tools, **kwargs)
|
||||
if isinstance(result, str):
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
|
||||
else:
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
|
||||
result
|
||||
)
|
||||
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
|
||||
call_usage["generated_tokens"] += _count_tokens(result)
|
||||
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
|
||||
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
|
||||
update_token_usage(
|
||||
self.decoded_token,
|
||||
self.user_api_key,
|
||||
call_usage,
|
||||
getattr(self, "agent_id", None),
|
||||
)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
@@ -49,17 +138,28 @@ def gen_token_usage(func):
|
||||
|
||||
def stream_token_usage(func):
|
||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||
for message in messages:
|
||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(
|
||||
message["content"]
|
||||
)
|
||||
usage_attachments = kwargs.pop("_usage_attachments", None)
|
||||
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
call_usage["prompt_tokens"] += _count_prompt_tokens(
|
||||
messages,
|
||||
tools=tools,
|
||||
usage_attachments=usage_attachments,
|
||||
**kwargs,
|
||||
)
|
||||
batch = []
|
||||
result = func(self, model, messages, stream, tools, **kwargs)
|
||||
for r in result:
|
||||
batch.append(r)
|
||||
yield r
|
||||
for line in batch:
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
|
||||
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
|
||||
call_usage["generated_tokens"] += _count_tokens(line)
|
||||
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
|
||||
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
|
||||
update_token_usage(
|
||||
self.decoded_token,
|
||||
self.user_api_key,
|
||||
call_usage,
|
||||
getattr(self, "agent_id", None),
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -322,6 +322,7 @@ def run_agent_logic(agent_config, input_data):
|
||||
chunks = int(agent_config.get("chunks", 2))
|
||||
prompt_id = agent_config.get("prompt_id", "default")
|
||||
user_api_key = agent_config["key"]
|
||||
agent_id = str(agent_config.get("_id")) if agent_config.get("_id") else None
|
||||
agent_type = agent_config.get("agent_type", "classic")
|
||||
decoded_token = {"sub": agent_config.get("user")}
|
||||
json_schema = agent_config.get("json_schema")
|
||||
@@ -352,6 +353,7 @@ def run_agent_logic(agent_config, input_data):
|
||||
doc_token_limit=doc_token_limit,
|
||||
model_id=model_id,
|
||||
user_api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
@@ -370,6 +372,7 @@ def run_agent_logic(agent_config, input_data):
|
||||
llm_name=provider or settings.LLM_PROVIDER,
|
||||
model_id=model_id,
|
||||
api_key=system_api_key,
|
||||
agent_id=agent_id,
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=[],
|
||||
|
||||
@@ -199,6 +199,7 @@ class TestStreamProcessorAgentConfiguration:
|
||||
try:
|
||||
processor._configure_agent()
|
||||
assert processor.agent_config is not None
|
||||
assert processor.agent_id == str(agent_id)
|
||||
except Exception as e:
|
||||
assert "Invalid API Key" in str(e)
|
||||
|
||||
@@ -211,6 +212,7 @@ class TestStreamProcessorAgentConfiguration:
|
||||
processor._configure_agent()
|
||||
|
||||
assert isinstance(processor.agent_config, dict)
|
||||
assert processor.agent_id is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
|
||||
326
tests/test_usage.py
Normal file
326
tests/test_usage.py
Normal file
@@ -0,0 +1,326 @@
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from application.usage import (
|
||||
_count_tokens,
|
||||
gen_token_usage,
|
||||
stream_token_usage,
|
||||
update_token_usage,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_count_tokens_includes_tool_call_payloads():
|
||||
payload = [
|
||||
{
|
||||
"function_call": {
|
||||
"name": "search_docs",
|
||||
"args": {"query": "pricing limits"},
|
||||
"call_id": "call_1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"function_response": {
|
||||
"name": "search_docs",
|
||||
"response": {"result": "Found 3 docs"},
|
||||
"call_id": "call_1",
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
assert _count_tokens(payload) > 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gen_token_usage_counts_structured_tool_content(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
captured["decoded_token"] = decoded_token
|
||||
captured["user_api_key"] = user_api_key
|
||||
captured["token_usage"] = token_usage.copy()
|
||||
captured["agent_id"] = agent_id
|
||||
|
||||
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
|
||||
|
||||
class DummyLLM:
|
||||
decoded_token = {"sub": "user_123"}
|
||||
user_api_key = "api_key_123"
|
||||
agent_id = "agent_123"
|
||||
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
|
||||
@gen_token_usage
|
||||
def wrapped(self, model, messages, stream, tools, **kwargs):
|
||||
_ = (model, messages, stream, tools, kwargs)
|
||||
return {
|
||||
"tool_calls": [
|
||||
{"name": "read_webpage", "arguments": {"url": "https://example.com"}}
|
||||
]
|
||||
}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"function_call": {
|
||||
"name": "search_docs",
|
||||
"args": {"query": "pricing"},
|
||||
"call_id": "1",
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{
|
||||
"function_response": {
|
||||
"name": "search_docs",
|
||||
"response": {"result": "Found docs"},
|
||||
"call_id": "1",
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
llm = DummyLLM()
|
||||
wrapped(llm, "gpt-4o", messages, False, None)
|
||||
|
||||
assert captured["decoded_token"] == {"sub": "user_123"}
|
||||
assert captured["user_api_key"] == "api_key_123"
|
||||
assert captured["agent_id"] == "agent_123"
|
||||
assert captured["token_usage"]["prompt_tokens"] > 0
|
||||
assert captured["token_usage"]["generated_tokens"] > 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_stream_token_usage_counts_tool_call_chunks(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
captured["token_usage"] = token_usage.copy()
|
||||
captured["agent_id"] = agent_id
|
||||
|
||||
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
|
||||
|
||||
class ToolChunk:
|
||||
def model_dump(self):
|
||||
return {
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location":"Seattle"}',
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
class DummyLLM:
|
||||
decoded_token = {"sub": "user_123"}
|
||||
user_api_key = "api_key_123"
|
||||
agent_id = "agent_123"
|
||||
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
|
||||
@stream_token_usage
|
||||
def wrapped(self, model, messages, stream, tools, **kwargs):
|
||||
_ = (model, messages, stream, tools, kwargs)
|
||||
yield ToolChunk()
|
||||
yield "done"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"function_call": {
|
||||
"name": "get_weather",
|
||||
"args": {"location": "Seattle"},
|
||||
"call_id": "1",
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
llm = DummyLLM()
|
||||
list(wrapped(llm, "gpt-4o", messages, True, None))
|
||||
|
||||
assert captured["agent_id"] == "agent_123"
|
||||
assert captured["token_usage"]["prompt_tokens"] > 0
|
||||
assert captured["token_usage"]["generated_tokens"] > 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gen_token_usage_counts_tools_and_image_inputs(monkeypatch):
|
||||
captured = []
|
||||
|
||||
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
_ = (decoded_token, user_api_key, agent_id)
|
||||
captured.append(token_usage.copy())
|
||||
|
||||
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
|
||||
|
||||
class DummyLLM:
|
||||
decoded_token = {"sub": "user_123"}
|
||||
user_api_key = "api_key_123"
|
||||
agent_id = "agent_123"
|
||||
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
|
||||
@gen_token_usage
|
||||
def wrapped(self, model, messages, stream, tools, **kwargs):
|
||||
_ = (model, messages, stream, tools, kwargs)
|
||||
return "ok"
|
||||
|
||||
messages = [{"role": "user", "content": "What is in this image?"}]
|
||||
tools_payload = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "describe_image",
|
||||
"description": "Describe image content",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"detail": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
usage_attachments = [
|
||||
{
|
||||
"mime_type": "image/png",
|
||||
"path": "attachments/example.png",
|
||||
"data": "abc123",
|
||||
}
|
||||
]
|
||||
|
||||
llm = DummyLLM()
|
||||
wrapped(llm, "gpt-4o", messages, False, None)
|
||||
wrapped(
|
||||
llm,
|
||||
"gpt-4o",
|
||||
messages,
|
||||
False,
|
||||
tools_payload,
|
||||
_usage_attachments=usage_attachments,
|
||||
)
|
||||
|
||||
assert len(captured) == 2
|
||||
assert captured[1]["prompt_tokens"] > captured[0]["prompt_tokens"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_stream_token_usage_counts_tools_and_image_inputs(monkeypatch):
|
||||
captured = []
|
||||
|
||||
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
_ = (decoded_token, user_api_key, agent_id)
|
||||
captured.append(token_usage.copy())
|
||||
|
||||
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
|
||||
|
||||
class DummyLLM:
|
||||
decoded_token = {"sub": "user_123"}
|
||||
user_api_key = "api_key_123"
|
||||
agent_id = "agent_123"
|
||||
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
|
||||
@stream_token_usage
|
||||
def wrapped(self, model, messages, stream, tools, **kwargs):
|
||||
_ = (model, messages, stream, tools, kwargs)
|
||||
yield "ok"
|
||||
|
||||
messages = [{"role": "user", "content": "What is in this image?"}]
|
||||
tools_payload = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "describe_image",
|
||||
"description": "Describe image content",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"detail": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
usage_attachments = [
|
||||
{
|
||||
"mime_type": "image/png",
|
||||
"path": "attachments/example.png",
|
||||
"data": "abc123",
|
||||
}
|
||||
]
|
||||
|
||||
llm = DummyLLM()
|
||||
list(wrapped(llm, "gpt-4o", messages, True, None))
|
||||
list(
|
||||
wrapped(
|
||||
llm,
|
||||
"gpt-4o",
|
||||
messages,
|
||||
True,
|
||||
tools_payload,
|
||||
_usage_attachments=usage_attachments,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(captured) == 2
|
||||
assert captured[1]["prompt_tokens"] > captured[0]["prompt_tokens"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_update_token_usage_inserts_with_agent_id_only(monkeypatch):
|
||||
inserted_docs = []
|
||||
|
||||
class FakeCollection:
|
||||
def insert_one(self, doc):
|
||||
inserted_docs.append(doc)
|
||||
|
||||
modules_without_pytest = dict(sys.modules)
|
||||
modules_without_pytest.pop("pytest", None)
|
||||
|
||||
monkeypatch.setattr("application.usage.sys.modules", modules_without_pytest)
|
||||
monkeypatch.setattr("application.usage.usage_collection", FakeCollection())
|
||||
|
||||
update_token_usage(
|
||||
decoded_token=None,
|
||||
user_api_key=None,
|
||||
token_usage={"prompt_tokens": 10, "generated_tokens": 5},
|
||||
agent_id="agent_123",
|
||||
)
|
||||
|
||||
assert len(inserted_docs) == 1
|
||||
assert inserted_docs[0]["agent_id"] == "agent_123"
|
||||
assert inserted_docs[0]["user_id"] is None
|
||||
assert inserted_docs[0]["api_key"] is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_update_token_usage_skips_when_all_ids_missing(monkeypatch):
|
||||
inserted_docs = []
|
||||
|
||||
class FakeCollection:
|
||||
def insert_one(self, doc):
|
||||
inserted_docs.append(doc)
|
||||
|
||||
modules_without_pytest = dict(sys.modules)
|
||||
modules_without_pytest.pop("pytest", None)
|
||||
|
||||
monkeypatch.setattr("application.usage.sys.modules", modules_without_pytest)
|
||||
monkeypatch.setattr("application.usage.usage_collection", FakeCollection())
|
||||
|
||||
update_token_usage(
|
||||
decoded_token=None,
|
||||
user_api_key=None,
|
||||
token_usage={"prompt_tokens": 10, "generated_tokens": 5},
|
||||
agent_id=None,
|
||||
)
|
||||
|
||||
assert inserted_docs == []
|
||||
Reference in New Issue
Block a user