mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
feat: better token counter
This commit is contained in:
@@ -2,7 +2,7 @@ import json
|
|||||||
from application.retriever.base import BaseRetriever
|
from application.retriever.base import BaseRetriever
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.llm.llm_creator import LLMCreator
|
from application.llm.llm_creator import LLMCreator
|
||||||
from application.utils import count_tokens
|
from application.utils import num_tokens_from_string
|
||||||
from langchain_community.tools import BraveSearch
|
from langchain_community.tools import BraveSearch
|
||||||
|
|
||||||
|
|
||||||
@@ -78,7 +78,7 @@ class BraveRetSearch(BaseRetriever):
|
|||||||
self.chat_history.reverse()
|
self.chat_history.reverse()
|
||||||
for i in self.chat_history:
|
for i in self.chat_history:
|
||||||
if "prompt" in i and "response" in i:
|
if "prompt" in i and "response" in i:
|
||||||
tokens_batch = count_tokens(i["prompt"]) + count_tokens(
|
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
|
||||||
i["response"]
|
i["response"]
|
||||||
)
|
)
|
||||||
if tokens_current_history + tokens_batch < self.token_limit:
|
if tokens_current_history + tokens_batch < self.token_limit:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from application.core.settings import settings
|
|||||||
from application.vectorstore.vector_creator import VectorCreator
|
from application.vectorstore.vector_creator import VectorCreator
|
||||||
from application.llm.llm_creator import LLMCreator
|
from application.llm.llm_creator import LLMCreator
|
||||||
|
|
||||||
from application.utils import count_tokens
|
from application.utils import num_tokens_from_string
|
||||||
|
|
||||||
|
|
||||||
class ClassicRAG(BaseRetriever):
|
class ClassicRAG(BaseRetriever):
|
||||||
@@ -98,7 +98,7 @@ class ClassicRAG(BaseRetriever):
|
|||||||
self.chat_history.reverse()
|
self.chat_history.reverse()
|
||||||
for i in self.chat_history:
|
for i in self.chat_history:
|
||||||
if "prompt" in i and "response" in i:
|
if "prompt" in i and "response" in i:
|
||||||
tokens_batch = count_tokens(i["prompt"]) + count_tokens(
|
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
|
||||||
i["response"]
|
i["response"]
|
||||||
)
|
)
|
||||||
if tokens_current_history + tokens_batch < self.token_limit:
|
if tokens_current_history + tokens_batch < self.token_limit:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from application.retriever.base import BaseRetriever
|
from application.retriever.base import BaseRetriever
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.llm.llm_creator import LLMCreator
|
from application.llm.llm_creator import LLMCreator
|
||||||
from application.utils import count_tokens
|
from application.utils import num_tokens_from_string
|
||||||
from langchain_community.tools import DuckDuckGoSearchResults
|
from langchain_community.tools import DuckDuckGoSearchResults
|
||||||
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
|
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
|
||||||
|
|
||||||
@@ -95,7 +95,7 @@ class DuckDuckSearch(BaseRetriever):
|
|||||||
self.chat_history.reverse()
|
self.chat_history.reverse()
|
||||||
for i in self.chat_history:
|
for i in self.chat_history:
|
||||||
if "prompt" in i and "response" in i:
|
if "prompt" in i and "response" in i:
|
||||||
tokens_batch = count_tokens(i["prompt"]) + count_tokens(
|
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
|
||||||
i["response"]
|
i["response"]
|
||||||
)
|
)
|
||||||
if tokens_current_history + tokens_batch < self.token_limit:
|
if tokens_current_history + tokens_batch < self.token_limit:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import sys
|
|||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.utils import count_tokens
|
from application.utils import num_tokens_from_string
|
||||||
|
|
||||||
mongo = MongoClient(settings.MONGO_URI)
|
mongo = MongoClient(settings.MONGO_URI)
|
||||||
db = mongo["docsgpt"]
|
db = mongo["docsgpt"]
|
||||||
@@ -24,9 +24,9 @@ def update_token_usage(user_api_key, token_usage):
|
|||||||
def gen_token_usage(func):
|
def gen_token_usage(func):
|
||||||
def wrapper(self, model, messages, stream, **kwargs):
|
def wrapper(self, model, messages, stream, **kwargs):
|
||||||
for message in messages:
|
for message in messages:
|
||||||
self.token_usage["prompt_tokens"] += count_tokens(message["content"])
|
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
|
||||||
result = func(self, model, messages, stream, **kwargs)
|
result = func(self, model, messages, stream, **kwargs)
|
||||||
self.token_usage["generated_tokens"] += count_tokens(result)
|
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
|
||||||
update_token_usage(self.user_api_key, self.token_usage)
|
update_token_usage(self.user_api_key, self.token_usage)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -36,14 +36,14 @@ def gen_token_usage(func):
|
|||||||
def stream_token_usage(func):
|
def stream_token_usage(func):
|
||||||
def wrapper(self, model, messages, stream, **kwargs):
|
def wrapper(self, model, messages, stream, **kwargs):
|
||||||
for message in messages:
|
for message in messages:
|
||||||
self.token_usage["prompt_tokens"] += count_tokens(message["content"])
|
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
|
||||||
batch = []
|
batch = []
|
||||||
result = func(self, model, messages, stream, **kwargs)
|
result = func(self, model, messages, stream, **kwargs)
|
||||||
for r in result:
|
for r in result:
|
||||||
batch.append(r)
|
batch.append(r)
|
||||||
yield r
|
yield r
|
||||||
for line in batch:
|
for line in batch:
|
||||||
self.token_usage["generated_tokens"] += count_tokens(line)
|
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
|
||||||
update_token_usage(self.user_api_key, self.token_usage)
|
update_token_usage(self.user_api_key, self.token_usage)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|||||||
@@ -1,6 +1,22 @@
|
|||||||
from transformers import GPT2TokenizerFast
|
import tiktoken
|
||||||
|
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
|
_encoding = None
|
||||||
tokenizer.model_max_length = 100000
|
|
||||||
def count_tokens(string):
|
def get_encoding():
|
||||||
return len(tokenizer(string)['input_ids'])
|
global _encoding
|
||||||
|
if _encoding is None:
|
||||||
|
_encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
|
return _encoding
|
||||||
|
|
||||||
|
def num_tokens_from_string(string: str) -> int:
|
||||||
|
encoding = get_encoding()
|
||||||
|
num_tokens = len(encoding.encode(string))
|
||||||
|
return num_tokens
|
||||||
|
|
||||||
|
def count_tokens_docs(docs):
|
||||||
|
docs_content = ""
|
||||||
|
for doc in docs:
|
||||||
|
docs_content += doc.page_content
|
||||||
|
|
||||||
|
tokens = num_tokens_from_string(docs_content)
|
||||||
|
return tokens
|
||||||
@@ -2,7 +2,6 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import string
|
import string
|
||||||
import zipfile
|
import zipfile
|
||||||
import tiktoken
|
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -14,6 +13,7 @@ from application.parser.remote.remote_creator import RemoteCreator
|
|||||||
from application.parser.open_ai_func import call_openai_api
|
from application.parser.open_ai_func import call_openai_api
|
||||||
from application.parser.schema.base import Document
|
from application.parser.schema.base import Document
|
||||||
from application.parser.token_func import group_split
|
from application.parser.token_func import group_split
|
||||||
|
from application.utils import count_tokens_docs
|
||||||
|
|
||||||
|
|
||||||
# Define a function to extract metadata from a given filename.
|
# Define a function to extract metadata from a given filename.
|
||||||
@@ -213,25 +213,3 @@ def remote_worker(self, source_data, name_job, user, loader, directory="temp"):
|
|||||||
shutil.rmtree(full_path)
|
shutil.rmtree(full_path)
|
||||||
|
|
||||||
return {"urls": source_data, "name_job": name_job, "user": user, "limited": False}
|
return {"urls": source_data, "name_job": name_job, "user": user, "limited": False}
|
||||||
|
|
||||||
|
|
||||||
def count_tokens_docs(docs):
|
|
||||||
# Here we convert the docs list to a string and calculate the number of tokens the string represents.
|
|
||||||
# docs_content = (" ".join(docs))
|
|
||||||
docs_content = ""
|
|
||||||
for doc in docs:
|
|
||||||
docs_content += doc.page_content
|
|
||||||
|
|
||||||
tokens, total_price = num_tokens_from_string(
|
|
||||||
string=docs_content, encoding_name="cl100k_base"
|
|
||||||
)
|
|
||||||
# Here we print the number of tokens and the approx user cost with some visually appealing formatting.
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
|
|
||||||
def num_tokens_from_string(string: str, encoding_name: str) -> int:
|
|
||||||
# Function to convert string to tokens and estimate user cost.
|
|
||||||
encoding = tiktoken.get_encoding(encoding_name)
|
|
||||||
num_tokens = len(encoding.encode(string))
|
|
||||||
total_price = (num_tokens / 1000) * 0.0004
|
|
||||||
return num_tokens, total_price
|
|
||||||
Reference in New Issue
Block a user