mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-23 04:41:47 +00:00
fix: token calc (#2285)
This commit is contained in:
@@ -1,22 +1,104 @@
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.utils import num_tokens_from_object_or_list, num_tokens_from_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
usage_collection = db["token_usage"]
|
||||
|
||||
|
||||
def update_token_usage(decoded_token, user_api_key, token_usage):
|
||||
def _serialize_for_token_count(value):
|
||||
"""Normalize payloads into token-countable primitives."""
|
||||
if isinstance(value, str):
|
||||
# Avoid counting large binary payloads in data URLs as text tokens.
|
||||
if value.startswith("data:") and ";base64," in value:
|
||||
return ""
|
||||
return value
|
||||
|
||||
if value is None:
|
||||
return ""
|
||||
|
||||
if isinstance(value, list):
|
||||
return [_serialize_for_token_count(item) for item in value]
|
||||
|
||||
if isinstance(value, dict):
|
||||
serialized = {}
|
||||
for key, raw in value.items():
|
||||
key_lower = str(key).lower()
|
||||
|
||||
# Skip raw binary-like fields; keep textual tool-call fields.
|
||||
if key_lower in {"data", "base64", "image_data"} and isinstance(raw, str):
|
||||
continue
|
||||
if key_lower == "url" and isinstance(raw, str) and ";base64," in raw:
|
||||
continue
|
||||
|
||||
serialized[key] = _serialize_for_token_count(raw)
|
||||
return serialized
|
||||
|
||||
if hasattr(value, "model_dump") and callable(getattr(value, "model_dump")):
|
||||
return _serialize_for_token_count(value.model_dump())
|
||||
if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")):
|
||||
return _serialize_for_token_count(value.to_dict())
|
||||
if hasattr(value, "__dict__"):
|
||||
return _serialize_for_token_count(vars(value))
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
def _count_tokens(value):
|
||||
serialized = _serialize_for_token_count(value)
|
||||
if isinstance(serialized, str):
|
||||
return num_tokens_from_string(serialized)
|
||||
return num_tokens_from_object_or_list(serialized)
|
||||
|
||||
|
||||
def _count_prompt_tokens(messages, tools=None, usage_attachments=None, **kwargs):
|
||||
prompt_tokens = 0
|
||||
|
||||
for message in messages or []:
|
||||
if not isinstance(message, dict):
|
||||
prompt_tokens += _count_tokens(message)
|
||||
continue
|
||||
|
||||
prompt_tokens += _count_tokens(message.get("content"))
|
||||
|
||||
# Include tool-related message fields for providers that use OpenAI-native format.
|
||||
prompt_tokens += _count_tokens(message.get("tool_calls"))
|
||||
prompt_tokens += _count_tokens(message.get("tool_call_id"))
|
||||
prompt_tokens += _count_tokens(message.get("function_call"))
|
||||
prompt_tokens += _count_tokens(message.get("function_response"))
|
||||
|
||||
# Count tool schema payload passed to the model.
|
||||
prompt_tokens += _count_tokens(tools)
|
||||
|
||||
# Count structured-output/schema payloads when provided.
|
||||
prompt_tokens += _count_tokens(kwargs.get("response_format"))
|
||||
prompt_tokens += _count_tokens(kwargs.get("response_schema"))
|
||||
|
||||
# Optional usage-only attachment context (not forwarded to provider).
|
||||
prompt_tokens += _count_tokens(usage_attachments)
|
||||
|
||||
return prompt_tokens
|
||||
|
||||
|
||||
def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
if "pytest" in sys.modules:
|
||||
return
|
||||
if decoded_token:
|
||||
user_id = decoded_token["sub"]
|
||||
else:
|
||||
user_id = None
|
||||
user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
|
||||
normalized_agent_id = str(agent_id) if agent_id else None
|
||||
|
||||
if not user_id and not user_api_key and not normalized_agent_id:
|
||||
logger.warning(
|
||||
"Skipping token usage insert: missing user_id, api_key, and agent_id"
|
||||
)
|
||||
return
|
||||
|
||||
usage_data = {
|
||||
"user_id": user_id,
|
||||
"api_key": user_api_key,
|
||||
@@ -24,24 +106,31 @@ def update_token_usage(decoded_token, user_api_key, token_usage):
|
||||
"generated_tokens": token_usage["generated_tokens"],
|
||||
"timestamp": datetime.now(),
|
||||
}
|
||||
if normalized_agent_id:
|
||||
usage_data["agent_id"] = normalized_agent_id
|
||||
usage_collection.insert_one(usage_data)
|
||||
|
||||
|
||||
def gen_token_usage(func):
|
||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||
for message in messages:
|
||||
if message["content"]:
|
||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(
|
||||
message["content"]
|
||||
)
|
||||
usage_attachments = kwargs.pop("_usage_attachments", None)
|
||||
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
call_usage["prompt_tokens"] += _count_prompt_tokens(
|
||||
messages,
|
||||
tools=tools,
|
||||
usage_attachments=usage_attachments,
|
||||
**kwargs,
|
||||
)
|
||||
result = func(self, model, messages, stream, tools, **kwargs)
|
||||
if isinstance(result, str):
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
|
||||
else:
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
|
||||
result
|
||||
)
|
||||
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
|
||||
call_usage["generated_tokens"] += _count_tokens(result)
|
||||
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
|
||||
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
|
||||
update_token_usage(
|
||||
self.decoded_token,
|
||||
self.user_api_key,
|
||||
call_usage,
|
||||
getattr(self, "agent_id", None),
|
||||
)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
@@ -49,17 +138,28 @@ def gen_token_usage(func):
|
||||
|
||||
def stream_token_usage(func):
|
||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||
for message in messages:
|
||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(
|
||||
message["content"]
|
||||
)
|
||||
usage_attachments = kwargs.pop("_usage_attachments", None)
|
||||
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
call_usage["prompt_tokens"] += _count_prompt_tokens(
|
||||
messages,
|
||||
tools=tools,
|
||||
usage_attachments=usage_attachments,
|
||||
**kwargs,
|
||||
)
|
||||
batch = []
|
||||
result = func(self, model, messages, stream, tools, **kwargs)
|
||||
for r in result:
|
||||
batch.append(r)
|
||||
yield r
|
||||
for line in batch:
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
|
||||
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
|
||||
call_usage["generated_tokens"] += _count_tokens(line)
|
||||
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
|
||||
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
|
||||
update_token_usage(
|
||||
self.decoded_token,
|
||||
self.user_api_key,
|
||||
call_usage,
|
||||
getattr(self, "agent_id", None),
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
Reference in New Issue
Block a user