From 4f88b6dc71148a51928ca9975ada75fc315ac900 Mon Sep 17 00:00:00 2001
From: Alex
Date: Sat, 31 Aug 2024 12:30:03 +0100
Subject: [PATCH 1/3] feat: logging
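
Route handlers, the Flask app and the Celery worker now share one
dictConfig-based setup from application/core/logging_config.py.
Request and error logs attach structured fields via `extra`; the
default formatter below renders only the message itself, so those
fields ride along on the LogRecord without being printed. A minimal
sketch of a formatter that would surface them (JsonExtraFormatter is
a hypothetical name, not part of this patch):

    import json
    import logging

    class JsonExtraFormatter(logging.Formatter):
        """Append fields passed via `extra` to each formatted line."""
        FIELDS = ("data", "error", "traceback")

        def format(self, record):
            base = super().format(record)
            # Attributes set through `extra` land directly on the record.
            payload = {k: getattr(record, k) for k in self.FIELDS
                       if hasattr(record, k)}
            return f"{base} {json.dumps(payload)}" if payload else base

Pointing a handler in the dictConfig at such a class would emit the
structured context alongside each message.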
---
application/api/answer/routes.py | 25 +++++++++++++++++++------
application/app.py | 2 ++
application/celery_init.py | 6 ++++++
application/core/logging_config.py | 22 ++++++++++++++++++++++
application/worker.py | 12 ++++++------
5 files changed, 55 insertions(+), 12 deletions(-)
create mode 100644 application/core/logging_config.py
diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index 893edd3a..a809b4ef 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -1,7 +1,7 @@
import asyncio
import os
import sys
-from flask import Blueprint, request, Response
+from flask import Blueprint, request, Response, current_app
import json
import datetime
import logging
@@ -267,6 +267,10 @@ def stream():
else:
retriever_name = source["active_docs"]
+ current_app.logger.info(f"/stream - request_data: {data}, source: {source}",
+ extra={"data": json.dumps({"request_data": data, "source": source})}
+ )
+
prompt = get_prompt(prompt_id)
retriever = RetrieverCreator.create_retriever(
@@ -301,7 +305,9 @@ def stream():
mimetype="text/event-stream",
)
except Exception as e:
- print("\033[91merr", str(e), file=sys.stderr)
+ current_app.logger.error(f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
+ extra={"error": str(e), "traceback": traceback.format_exc()}
+ )
message = e.args[0]
status_code = 400
# # Custom exceptions with two arguments, index 1 as status code
@@ -345,7 +351,6 @@ def api_answer():
else:
token_limit = settings.DEFAULT_MAX_HISTORY
- # use try and except to check for exception
try:
# check if the vectorstore is set
if "api_key" in data:
@@ -365,6 +370,10 @@ def api_answer():
prompt = get_prompt(prompt_id)
+ current_app.logger.info(f"/api/answer - request_data: {data}, source: {source}",
+ extra={"data": json.dumps({"request_data": data, "source": source})}
+ )
+
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
@@ -399,9 +408,9 @@ def api_answer():
return result
except Exception as e:
- # print whole traceback
- traceback.print_exc()
- print(str(e))
+ current_app.logger.error(f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
+ extra={"error": str(e), "traceback": traceback.format_exc()}
+ )
return bad_request(500, str(e))
@@ -433,6 +442,10 @@ def api_search():
token_limit = data["token_limit"]
else:
token_limit = settings.DEFAULT_MAX_HISTORY
+
+ current_app.logger.info(f"/api/answer - request_data: {data}, source: {source}",
+ extra={"data": json.dumps({"request_data": data, "source": source})}
+ )
retriever = RetrieverCreator.create_retriever(
retriever_name,
diff --git a/application/app.py b/application/app.py
index fe8efd12..87d9d42f 100644
--- a/application/app.py
+++ b/application/app.py
@@ -6,12 +6,14 @@ from application.core.settings import settings
from application.api.user.routes import user
from application.api.answer.routes import answer
from application.api.internal.routes import internal
+from application.core.logging_config import setup_logging
if platform.system() == "Windows":
import pathlib
pathlib.PosixPath = pathlib.WindowsPath
dotenv.load_dotenv()
+setup_logging()
app = Flask(__name__)
app.register_blueprint(user)
diff --git a/application/celery_init.py b/application/celery_init.py
index c19c2e75..c5838083 100644
--- a/application/celery_init.py
+++ b/application/celery_init.py
@@ -1,9 +1,15 @@
from celery import Celery
from application.core.settings import settings
+from celery.signals import setup_logging
def make_celery(app_name=__name__):
celery = Celery(app_name, broker=settings.CELERY_BROKER_URL, backend=settings.CELERY_RESULT_BACKEND)
celery.conf.update(settings)
return celery
+@setup_logging.connect
+def config_loggers(*args, **kwargs):
+ from application.core.logging_config import setup_logging
+ setup_logging()
+
celery = make_celery()
diff --git a/application/core/logging_config.py b/application/core/logging_config.py
new file mode 100644
index 00000000..e693cb91
--- /dev/null
+++ b/application/core/logging_config.py
@@ -0,0 +1,22 @@
+from logging.config import dictConfig
+
+def setup_logging():
+ dictConfig({
+ 'version': 1,
+ 'formatters': {
+ 'default': {
+ 'format': '[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
+ }
+ },
+ 'handlers': {
+ 'console': {
+ 'class': 'logging.StreamHandler',
+ 'stream': 'ext://sys.stdout',
+ 'formatter': 'default',
+ }
+ },
+ 'root': {
+ 'level': 'INFO',
+ 'handlers': ['console'],
+ },
+ })
\ No newline at end of file
diff --git a/application/worker.py b/application/worker.py
index bd1bc15a..3105aabe 100755
--- a/application/worker.py
+++ b/application/worker.py
@@ -4,6 +4,7 @@ import string
import zipfile
import tiktoken
from urllib.parse import urljoin
+import logging
import requests
@@ -14,6 +15,7 @@ from application.parser.open_ai_func import call_openai_api
from application.parser.schema.base import Document
from application.parser.token_func import group_split
+
# Define a function to extract metadata from a given filename.
def metadata_from_filename(title):
store = "/".join(title.split("/")[1:3])
@@ -41,7 +43,7 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
max_depth (int): Maximum allowed depth of recursion to prevent infinite loops.
"""
if current_depth > max_depth:
- print(f"Reached maximum recursion depth of {max_depth}")
+ logging.warning(f"Reached maximum recursion depth of {max_depth}")
return
with zipfile.ZipFile(zip_path, "r") as zip_ref:
@@ -88,16 +90,13 @@ def ingest_worker(self, directory, formats, name_job, filename, user):
max_tokens = 1250
recursion_depth = 2
full_path = os.path.join(directory, user, name_job)
- import sys
- print(full_path, file=sys.stderr)
+ logging.info(f"Ingest file: {full_path}", extra={"user": user, "job": name_job})
# check if API_URL env variable is set
file_data = {"name": name_job, "file": filename, "user": user}
response = requests.get(
urljoin(settings.API_URL, "/api/download"), params=file_data
)
- # check if file is in the response
- print(response, file=sys.stderr)
file = response.content
if not os.path.exists(full_path):
@@ -137,7 +136,7 @@ def ingest_worker(self, directory, formats, name_job, filename, user):
if sample:
for i in range(min(5, len(raw_docs))):
- print(raw_docs[i].text)
+ logging.info(f"Sample document {i}: {raw_docs[i]}")
# get files from outputs/inputs/index.faiss and outputs/inputs/index.pkl
# and send them to the server (provide user and name in form)
@@ -180,6 +179,7 @@ def remote_worker(self, source_data, name_job, user, loader, directory="temp"):
if not os.path.exists(full_path):
os.makedirs(full_path)
self.update_state(state="PROGRESS", meta={"current": 1})
+ logging.info(f"Remote job: {full_path}", extra={"user": user, "job": name_job, source_data: source_data})
remote_loader = RemoteCreator.create_loader(loader)
raw_docs = remote_loader.load_data(source_data)
From c49b7613e03174a56cb5da76f643bd04dada627e Mon Sep 17 00:00:00 2001
From: Alex
Date: Sat, 31 Aug 2024 12:53:37 +0100
Subject: [PATCH 2/3] fix: langchain warning
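
langchain moved its document loaders into the separate
langchain_community package, so importing WebBaseLoader from
langchain.document_loaders now emits a LangChainDeprecationWarning.
Both remote loaders switch to the new import path. If older langchain
releases without the community split still had to be supported, a
guarded import would be one option (sketched here as an assumption,
not used by this patch):

    try:
        from langchain_community.document_loaders import WebBaseLoader
    except ImportError:
        # Fall back for releases that predate langchain_community.
        from langchain.document_loaders import WebBaseLoader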
---
application/parser/remote/crawler_loader.py | 2 +-
application/parser/remote/sitemap_loader.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/application/parser/remote/crawler_loader.py b/application/parser/remote/crawler_loader.py
index 2a63f284..76325ae6 100644
--- a/application/parser/remote/crawler_loader.py
+++ b/application/parser/remote/crawler_loader.py
@@ -5,7 +5,7 @@ from application.parser.remote.base import BaseRemote
class CrawlerLoader(BaseRemote):
def __init__(self, limit=10):
- from langchain.document_loaders import WebBaseLoader
+ from langchain_community.document_loaders import WebBaseLoader
self.loader = WebBaseLoader # Initialize the document loader
self.limit = limit # Set the limit for the number of pages to scrape
diff --git a/application/parser/remote/sitemap_loader.py b/application/parser/remote/sitemap_loader.py
index 6e9182c4..8066f4f6 100644
--- a/application/parser/remote/sitemap_loader.py
+++ b/application/parser/remote/sitemap_loader.py
@@ -5,7 +5,7 @@ from application.parser.remote.base import BaseRemote
class SitemapLoader(BaseRemote):
def __init__(self, limit=20):
- from langchain.document_loaders import WebBaseLoader
+ from langchain_community.document_loaders import WebBaseLoader
self.loader = WebBaseLoader
self.limit = limit # Adding limit to control the number of URLs to process
From d9309ebc6eaee86bc5bc0f41a83d6e67f12990a6 Mon Sep 17 00:00:00 2001
From: Alex
Date: Sat, 31 Aug 2024 17:07:40 +0100
Subject: [PATCH 3/3] feat: better token counter
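
Replace the GPT2TokenizerFast-based count_tokens with a tiktoken
cl100k_base encoder that is built once and cached, rename it to
num_tokens_from_string, and move count_tokens_docs from worker.py
into application/utils.py, dropping the unused price estimate.
Expected usage (a sketch; the token count assumes the cl100k_base
encoding):

    from application.utils import num_tokens_from_string

    # The first call builds the encoding; later calls reuse the cache.
    n = num_tokens_from_string("hello world")  # 2 tokens
    print(n)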
---
application/retriever/brave_search.py | 4 ++--
application/retriever/classic_rag.py | 4 ++--
application/retriever/duckduck_search.py | 4 ++--
application/usage.py | 10 ++++-----
application/utils.py | 26 +++++++++++++++++++-----
application/worker.py | 26 ++----------------------
6 files changed, 34 insertions(+), 40 deletions(-)
diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py
index 70dbbf20..5d1e1566 100644
--- a/application/retriever/brave_search.py
+++ b/application/retriever/brave_search.py
@@ -2,7 +2,7 @@ import json
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
-from application.utils import count_tokens
+from application.utils import num_tokens_from_string
from langchain_community.tools import BraveSearch
@@ -78,7 +78,7 @@ class BraveRetSearch(BaseRetriever):
self.chat_history.reverse()
for i in self.chat_history:
if "prompt" in i and "response" in i:
- tokens_batch = count_tokens(i["prompt"]) + count_tokens(
+ tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
if tokens_current_history + tokens_batch < self.token_limit:
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index 2b77db34..aef6e503 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -4,7 +4,7 @@ from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
-from application.utils import count_tokens
+from application.utils import num_tokens_from_string
class ClassicRAG(BaseRetriever):
@@ -98,7 +98,7 @@ class ClassicRAG(BaseRetriever):
self.chat_history.reverse()
for i in self.chat_history:
if "prompt" in i and "response" in i:
- tokens_batch = count_tokens(i["prompt"]) + count_tokens(
+ tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
if tokens_current_history + tokens_batch < self.token_limit:
diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py
index bee74e24..6d2965f5 100644
--- a/application/retriever/duckduck_search.py
+++ b/application/retriever/duckduck_search.py
@@ -1,7 +1,7 @@
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
-from application.utils import count_tokens
+from application.utils import num_tokens_from_string
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
@@ -95,7 +95,7 @@ class DuckDuckSearch(BaseRetriever):
self.chat_history.reverse()
for i in self.chat_history:
if "prompt" in i and "response" in i:
- tokens_batch = count_tokens(i["prompt"]) + count_tokens(
+ tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
if tokens_current_history + tokens_batch < self.token_limit:
diff --git a/application/usage.py b/application/usage.py
index 1b26e9d7..aba0ec77 100644
--- a/application/usage.py
+++ b/application/usage.py
@@ -2,7 +2,7 @@ import sys
from pymongo import MongoClient
from datetime import datetime
from application.core.settings import settings
-from application.utils import count_tokens
+from application.utils import num_tokens_from_string
mongo = MongoClient(settings.MONGO_URI)
db = mongo["docsgpt"]
@@ -24,9 +24,9 @@ def update_token_usage(user_api_key, token_usage):
def gen_token_usage(func):
def wrapper(self, model, messages, stream, **kwargs):
for message in messages:
- self.token_usage["prompt_tokens"] += count_tokens(message["content"])
+ self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
result = func(self, model, messages, stream, **kwargs)
- self.token_usage["generated_tokens"] += count_tokens(result)
+ self.token_usage["generated_tokens"] += num_tokens_from_string(result)
update_token_usage(self.user_api_key, self.token_usage)
return result
@@ -36,14 +36,14 @@ def gen_token_usage(func):
def stream_token_usage(func):
def wrapper(self, model, messages, stream, **kwargs):
for message in messages:
- self.token_usage["prompt_tokens"] += count_tokens(message["content"])
+ self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
batch = []
result = func(self, model, messages, stream, **kwargs)
for r in result:
batch.append(r)
yield r
for line in batch:
- self.token_usage["generated_tokens"] += count_tokens(line)
+ self.token_usage["generated_tokens"] += num_tokens_from_string(line)
update_token_usage(self.user_api_key, self.token_usage)
return wrapper
diff --git a/application/utils.py b/application/utils.py
index 3d9bf520..70a00ce0 100644
--- a/application/utils.py
+++ b/application/utils.py
@@ -1,6 +1,22 @@
-from transformers import GPT2TokenizerFast
+import tiktoken
-tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
-tokenizer.model_max_length = 100000
-def count_tokens(string):
- return len(tokenizer(string)['input_ids'])
\ No newline at end of file
+_encoding = None
+
+def get_encoding():
+ global _encoding
+ if _encoding is None:
+ _encoding = tiktoken.get_encoding("cl100k_base")
+ return _encoding
+
+def num_tokens_from_string(string: str) -> int:
+ encoding = get_encoding()
+ num_tokens = len(encoding.encode(string))
+ return num_tokens
+
+def count_tokens_docs(docs):
+ docs_content = ""
+ for doc in docs:
+ docs_content += doc.page_content
+
+ tokens = num_tokens_from_string(docs_content)
+ return tokens
\ No newline at end of file
diff --git a/application/worker.py b/application/worker.py
index 3105aabe..c315f916 100755
--- a/application/worker.py
+++ b/application/worker.py
@@ -2,7 +2,6 @@ import os
import shutil
import string
import zipfile
-import tiktoken
from urllib.parse import urljoin
import logging
@@ -14,6 +13,7 @@ from application.parser.remote.remote_creator import RemoteCreator
from application.parser.open_ai_func import call_openai_api
from application.parser.schema.base import Document
from application.parser.token_func import group_split
+from application.utils import count_tokens_docs
# Define a function to extract metadata from a given filename.
@@ -212,26 +212,4 @@ def remote_worker(self, source_data, name_job, user, loader, directory="temp"):
shutil.rmtree(full_path)
- return {"urls": source_data, "name_job": name_job, "user": user, "limited": False}
-
-
-def count_tokens_docs(docs):
- # Here we convert the docs list to a string and calculate the number of tokens the string represents.
- # docs_content = (" ".join(docs))
- docs_content = ""
- for doc in docs:
- docs_content += doc.page_content
-
- tokens, total_price = num_tokens_from_string(
- string=docs_content, encoding_name="cl100k_base"
- )
- # Here we print the number of tokens and the approx user cost with some visually appealing formatting.
- return tokens
-
-
-def num_tokens_from_string(string: str, encoding_name: str) -> int:
- # Function to convert string to tokens and estimate user cost.
- encoding = tiktoken.get_encoding(encoding_name)
- num_tokens = len(encoding.encode(string))
- total_price = (num_tokens / 1000) * 0.0004
- return num_tokens, total_price
\ No newline at end of file
+ return {"urls": source_data, "name_job": name_job, "user": user, "limited": False}
\ No newline at end of file