feat: better token counter

This commit is contained in:
Alex
2024-08-31 17:07:40 +01:00
parent c49b7613e0
commit d9309ebc6e
6 changed files with 34 additions and 40 deletions

View File

@@ -2,7 +2,7 @@ import json
 from application.retriever.base import BaseRetriever
 from application.core.settings import settings
 from application.llm.llm_creator import LLMCreator
-from application.utils import count_tokens
+from application.utils import num_tokens_from_string
 from langchain_community.tools import BraveSearch
@@ -78,7 +78,7 @@ class BraveRetSearch(BaseRetriever):
         self.chat_history.reverse()
         for i in self.chat_history:
             if "prompt" in i and "response" in i:
-                tokens_batch = count_tokens(i["prompt"]) + count_tokens(
+                tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
                     i["response"]
                 )
                 if tokens_current_history + tokens_batch < self.token_limit:

View File

@@ -4,7 +4,7 @@ from application.core.settings import settings
 from application.vectorstore.vector_creator import VectorCreator
 from application.llm.llm_creator import LLMCreator
-from application.utils import count_tokens
+from application.utils import num_tokens_from_string

 class ClassicRAG(BaseRetriever):
@@ -98,7 +98,7 @@ class ClassicRAG(BaseRetriever):
         self.chat_history.reverse()
         for i in self.chat_history:
             if "prompt" in i and "response" in i:
-                tokens_batch = count_tokens(i["prompt"]) + count_tokens(
+                tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
                     i["response"]
                 )
                 if tokens_current_history + tokens_batch < self.token_limit:

View File

@@ -1,7 +1,7 @@
 from application.retriever.base import BaseRetriever
 from application.core.settings import settings
 from application.llm.llm_creator import LLMCreator
-from application.utils import count_tokens
+from application.utils import num_tokens_from_string
 from langchain_community.tools import DuckDuckGoSearchResults
 from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
@@ -95,7 +95,7 @@ class DuckDuckSearch(BaseRetriever):
         self.chat_history.reverse()
         for i in self.chat_history:
             if "prompt" in i and "response" in i:
-                tokens_batch = count_tokens(i["prompt"]) + count_tokens(
+                tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
                     i["response"]
                 )
                 if tokens_current_history + tokens_batch < self.token_limit:

View File

@@ -2,7 +2,7 @@ import sys
 from pymongo import MongoClient
 from datetime import datetime
 from application.core.settings import settings
-from application.utils import count_tokens
+from application.utils import num_tokens_from_string

 mongo = MongoClient(settings.MONGO_URI)
 db = mongo["docsgpt"]
@@ -24,9 +24,9 @@ def update_token_usage(user_api_key, token_usage):
 def gen_token_usage(func):
     def wrapper(self, model, messages, stream, **kwargs):
         for message in messages:
-            self.token_usage["prompt_tokens"] += count_tokens(message["content"])
+            self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
         result = func(self, model, messages, stream, **kwargs)
-        self.token_usage["generated_tokens"] += count_tokens(result)
+        self.token_usage["generated_tokens"] += num_tokens_from_string(result)
         update_token_usage(self.user_api_key, self.token_usage)
         return result
@@ -36,14 +36,14 @@ def gen_token_usage(func):
 def stream_token_usage(func):
     def wrapper(self, model, messages, stream, **kwargs):
         for message in messages:
-            self.token_usage["prompt_tokens"] += count_tokens(message["content"])
+            self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
         batch = []
         result = func(self, model, messages, stream, **kwargs)
         for r in result:
             batch.append(r)
             yield r
         for line in batch:
-            self.token_usage["generated_tokens"] += count_tokens(line)
+            self.token_usage["generated_tokens"] += num_tokens_from_string(line)
         update_token_usage(self.user_api_key, self.token_usage)
     return wrapper

View File

@@ -1,6 +1,22 @@
-from transformers import GPT2TokenizerFast
-tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
-tokenizer.model_max_length = 100000
-def count_tokens(string):
-    return len(tokenizer(string)['input_ids'])
+import tiktoken
+
+_encoding = None
+
+def get_encoding():
+    global _encoding
+    if _encoding is None:
+        _encoding = tiktoken.get_encoding("cl100k_base")
+    return _encoding
+
+def num_tokens_from_string(string: str) -> int:
+    encoding = get_encoding()
+    num_tokens = len(encoding.encode(string))
+    return num_tokens
+
+def count_tokens_docs(docs):
+    docs_content = ""
+    for doc in docs:
+        docs_content += doc.page_content
+    tokens = num_tokens_from_string(docs_content)
+    return tokens

View File

@@ -2,7 +2,6 @@ import os
 import shutil
 import string
 import zipfile
-import tiktoken
 from urllib.parse import urljoin
 import logging
@@ -14,6 +13,7 @@ from application.parser.remote.remote_creator import RemoteCreator
 from application.parser.open_ai_func import call_openai_api
 from application.parser.schema.base import Document
 from application.parser.token_func import group_split
+from application.utils import count_tokens_docs

 # Define a function to extract metadata from a given filename.
@@ -213,25 +213,3 @@ def remote_worker(self, source_data, name_job, user, loader, directory="temp"):
     shutil.rmtree(full_path)
     return {"urls": source_data, "name_job": name_job, "user": user, "limited": False}
-
-def count_tokens_docs(docs):
-    # Here we convert the docs list to a string and calculate the number of tokens the string represents.
-    # docs_content = (" ".join(docs))
-    docs_content = ""
-    for doc in docs:
-        docs_content += doc.page_content
-    tokens, total_price = num_tokens_from_string(
-        string=docs_content, encoding_name="cl100k_base"
-    )
-    # Here we print the number of tokens and the approx user cost with some visually appealing formatting.
-    return tokens
-
-def num_tokens_from_string(string: str, encoding_name: str) -> int:
-    # Function to convert string to tokens and estimate user cost.
-    encoding = tiktoken.get_encoding(encoding_name)
-    num_tokens = len(encoding.encode(string))
-    total_price = (num_tokens / 1000) * 0.0004
-    return num_tokens, total_price