diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py index 70dbbf20..5d1e1566 100644 --- a/application/retriever/brave_search.py +++ b/application/retriever/brave_search.py @@ -2,7 +2,7 @@ import json from application.retriever.base import BaseRetriever from application.core.settings import settings from application.llm.llm_creator import LLMCreator -from application.utils import count_tokens +from application.utils import num_tokens_from_string from langchain_community.tools import BraveSearch @@ -78,7 +78,7 @@ class BraveRetSearch(BaseRetriever): self.chat_history.reverse() for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens( + tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string( i["response"] ) if tokens_current_history + tokens_batch < self.token_limit: diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 2b77db34..aef6e503 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -4,7 +4,7 @@ from application.core.settings import settings from application.vectorstore.vector_creator import VectorCreator from application.llm.llm_creator import LLMCreator -from application.utils import count_tokens +from application.utils import num_tokens_from_string class ClassicRAG(BaseRetriever): @@ -98,7 +98,7 @@ class ClassicRAG(BaseRetriever): self.chat_history.reverse() for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens( + tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string( i["response"] ) if tokens_current_history + tokens_batch < self.token_limit: diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py index bee74e24..6d2965f5 100644 --- a/application/retriever/duckduck_search.py +++ b/application/retriever/duckduck_search.py @@ -1,7 +1,7 @@ from application.retriever.base import BaseRetriever from application.core.settings import settings from application.llm.llm_creator import LLMCreator -from application.utils import count_tokens +from application.utils import num_tokens_from_string from langchain_community.tools import DuckDuckGoSearchResults from langchain_community.utilities import DuckDuckGoSearchAPIWrapper @@ -95,7 +95,7 @@ class DuckDuckSearch(BaseRetriever): self.chat_history.reverse() for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens( + tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string( i["response"] ) if tokens_current_history + tokens_batch < self.token_limit: diff --git a/application/usage.py b/application/usage.py index 1b26e9d7..aba0ec77 100644 --- a/application/usage.py +++ b/application/usage.py @@ -2,7 +2,7 @@ import sys from pymongo import MongoClient from datetime import datetime from application.core.settings import settings -from application.utils import count_tokens +from application.utils import num_tokens_from_string mongo = MongoClient(settings.MONGO_URI) db = mongo["docsgpt"] @@ -24,9 +24,9 @@ def update_token_usage(user_api_key, token_usage): def gen_token_usage(func): def wrapper(self, model, messages, stream, **kwargs): for message in messages: - self.token_usage["prompt_tokens"] += count_tokens(message["content"]) + self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) result = func(self, model, messages, stream, **kwargs) - self.token_usage["generated_tokens"] += count_tokens(result) + self.token_usage["generated_tokens"] += num_tokens_from_string(result) update_token_usage(self.user_api_key, self.token_usage) return result @@ -36,14 +36,14 @@ def gen_token_usage(func): def stream_token_usage(func): def wrapper(self, model, messages, stream, **kwargs): for message in messages: - self.token_usage["prompt_tokens"] += count_tokens(message["content"]) + self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) batch = [] result = func(self, model, messages, stream, **kwargs) for r in result: batch.append(r) yield r for line in batch: - self.token_usage["generated_tokens"] += count_tokens(line) + self.token_usage["generated_tokens"] += num_tokens_from_string(line) update_token_usage(self.user_api_key, self.token_usage) return wrapper diff --git a/application/utils.py b/application/utils.py index 3d9bf520..70a00ce0 100644 --- a/application/utils.py +++ b/application/utils.py @@ -1,6 +1,22 @@ -from transformers import GPT2TokenizerFast +import tiktoken -tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') -tokenizer.model_max_length = 100000 -def count_tokens(string): - return len(tokenizer(string)['input_ids']) \ No newline at end of file +_encoding = None + +def get_encoding(): + global _encoding + if _encoding is None: + _encoding = tiktoken.get_encoding("cl100k_base") + return _encoding + +def num_tokens_from_string(string: str) -> int: + encoding = get_encoding() + num_tokens = len(encoding.encode(string)) + return num_tokens + +def count_tokens_docs(docs): + docs_content = "" + for doc in docs: + docs_content += doc.page_content + + tokens = num_tokens_from_string(docs_content) + return tokens \ No newline at end of file diff --git a/application/worker.py b/application/worker.py index 3105aabe..c315f916 100755 --- a/application/worker.py +++ b/application/worker.py @@ -2,7 +2,6 @@ import os import shutil import string import zipfile -import tiktoken from urllib.parse import urljoin import logging @@ -14,6 +13,7 @@ from application.parser.remote.remote_creator import RemoteCreator from application.parser.open_ai_func import call_openai_api from application.parser.schema.base import Document from application.parser.token_func import group_split +from application.utils import count_tokens_docs # Define a function to extract metadata from a given filename. @@ -212,26 +212,4 @@ def remote_worker(self, source_data, name_job, user, loader, directory="temp"): shutil.rmtree(full_path) - return {"urls": source_data, "name_job": name_job, "user": user, "limited": False} - - -def count_tokens_docs(docs): - # Here we convert the docs list to a string and calculate the number of tokens the string represents. - # docs_content = (" ".join(docs)) - docs_content = "" - for doc in docs: - docs_content += doc.page_content - - tokens, total_price = num_tokens_from_string( - string=docs_content, encoding_name="cl100k_base" - ) - # Here we print the number of tokens and the approx user cost with some visually appealing formatting. - return tokens - - -def num_tokens_from_string(string: str, encoding_name: str) -> int: - # Function to convert string to tokens and estimate user cost. - encoding = tiktoken.get_encoding(encoding_name) - num_tokens = len(encoding.encode(string)) - total_price = (num_tokens / 1000) * 0.0004 - return num_tokens, total_price \ No newline at end of file + return {"urls": source_data, "name_job": name_job, "user": user, "limited": False} \ No newline at end of file