Merge pull request #965 from siiddhantt/feature/set-tokens-message-history

feat: dropdown to adjust conversational history limits
Alex authored 2024-05-28 09:43:21 +01:00; committed by GitHub
11 changed files with 152 additions and 27 deletions
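
With this change, a client can cap how much past conversation is replayed to the model by sending a token_limit field in the request body; when the field is omitted, the server falls back to settings.DEFAULT_MAX_HISTORY. A minimal client-side sketch (the host, port, and endpoint path are assumptions for illustration; the diff only shows the handler functions stream(), api_answer(), and api_search()):

import requests

# Hypothetical request: "token_limit" and "active_docs" appear in this PR's
# handlers; the other field names, host, port, and path are illustrative only.
payload = {
    "question": "How do I switch the vector store to qdrant?",
    "history": [],
    "active_docs": "default",
    "token_limit": 1000,  # new in this PR; omit to use DEFAULT_MAX_HISTORY
}
response = requests.post("http://localhost:7091/api/answer", json=payload)
print(response.json())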

View File

@@ -78,7 +78,7 @@ def get_data_from_api_key(api_key):
if data is None:
return bad_request(401, "Invalid API key")
return data
def get_vectorstore(data):
if "active_docs" in data:
@@ -95,6 +95,7 @@ def get_vectorstore(data):
vectorstore = os.path.join("application", vectorstore)
return vectorstore
def is_azure_configured():
return (
settings.OPENAI_API_BASE
@@ -221,7 +222,10 @@ def stream():
chunks = int(data["chunks"])
else:
chunks = 2
if "token_limit" in data:
token_limit = data["token_limit"]
else:
token_limit = settings.DEFAULT_MAX_HISTORY
# check if active_docs or api_key is set
@@ -255,6 +259,7 @@ def stream():
chat_history=history,
prompt=prompt,
chunks=chunks,
+ token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
@@ -291,6 +296,10 @@ def api_answer():
chunks = int(data["chunks"])
else:
chunks = 2
if "token_limit" in data:
token_limit = data["token_limit"]
else:
token_limit = settings.DEFAULT_MAX_HISTORY
# use try and except to check for exception
try:
@@ -314,7 +323,7 @@ def api_answer():
retriever_name = source["active_docs"]
prompt = get_prompt(prompt_id)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
@@ -322,6 +331,7 @@ def api_answer():
chat_history=history,
prompt=prompt,
chunks=chunks,
+ token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
@@ -370,7 +380,6 @@ def api_search():
else:
source = {}
user_api_key = None
if (
source["active_docs"].split("/")[0] == "default"
@@ -379,6 +388,10 @@ def api_search():
retriever_name = "classic"
else:
retriever_name = source["active_docs"]
if "token_limit" in data:
token_limit = data["token_limit"]
else:
token_limit = settings.DEFAULT_MAX_HISTORY
retriever = RetrieverCreator.create_retriever(
retriever_name,
@@ -387,8 +400,9 @@ def api_search():
chat_history=[],
prompt="default",
chunks=chunks,
+ token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
docs = retriever.search()
return docs
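
The same four-line fallback for token_limit is repeated in stream(), api_answer(), and api_search(). Purely as an editorial aside (not part of the commit), a dict.get one-liner would be an equivalent, more compact form:

# Equivalent to the repeated `if "token_limit" in data: ... else: ...` block.
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)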

View File

@@ -15,7 +15,8 @@ class Settings(BaseSettings):
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
- TOKENS_MAX_HISTORY: int = 150
+ DEFAULT_MAX_HISTORY: int = 150
+ MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
UPLOAD_FOLDER: str = "inputs"
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant"
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
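
The two new settings work as a pair: DEFAULT_MAX_HISTORY is the fallback history budget, while MODEL_TOKEN_LIMITS caps what a request may ask for on a per-model basis. A minimal sketch of how the retriever diffs below combine them (the function name and import path are assumptions for illustration):

from application.core.settings import settings  # import path is an assumption

def effective_token_limit(requested: int, gpt_model: str) -> int:
    # Clamp the requested budget to the model's known maximum; models not
    # listed in MODEL_TOKEN_LIMITS fall back to DEFAULT_MAX_HISTORY.
    model_max = settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
    return min(requested, model_max)

# effective_token_limit(1_000_000, "gpt-3.5-turbo") -> 4096
# effective_token_limit(100, "docsgpt")             -> 100

This min() form is behaviorally identical to the conditional expression each retriever adds below.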

View File

@@ -15,6 +15,7 @@ class BraveRetSearch(BaseRetriever):
chat_history,
prompt,
chunks=2,
+ token_limit=150,
gpt_model="docsgpt",
user_api_key=None,
):
@@ -24,6 +25,16 @@ class BraveRetSearch(BaseRetriever):
self.prompt = prompt
self.chunks = chunks
self.gpt_model = gpt_model
+ self.token_limit = (
+     token_limit
+     if token_limit
+     < settings.MODEL_TOKEN_LIMITS.get(
+         self.gpt_model, settings.DEFAULT_MAX_HISTORY
+     )
+     else settings.MODEL_TOKEN_LIMITS.get(
+         self.gpt_model, settings.DEFAULT_MAX_HISTORY
+     )
+ )
self.user_api_key = user_api_key
def _get_data(self):
@@ -70,10 +81,7 @@ class BraveRetSearch(BaseRetriever):
tokens_batch = count_tokens(i["prompt"]) + count_tokens(
i["response"]
)
- if (
-     tokens_current_history + tokens_batch
-     < settings.TOKENS_MAX_HISTORY
- ):
+ if tokens_current_history + tokens_batch < self.token_limit:
tokens_current_history += tokens_batch
messages_combine.append(
{"role": "user", "content": i["prompt"]}

View File

@@ -16,6 +16,7 @@ class ClassicRAG(BaseRetriever):
chat_history,
prompt,
chunks=2,
+ token_limit=150,
gpt_model="docsgpt",
user_api_key=None,
):
@@ -25,6 +26,16 @@ class ClassicRAG(BaseRetriever):
self.prompt = prompt
self.chunks = chunks
self.gpt_model = gpt_model
+ self.token_limit = (
+     token_limit
+     if token_limit
+     < settings.MODEL_TOKEN_LIMITS.get(
+         self.gpt_model, settings.DEFAULT_MAX_HISTORY
+     )
+     else settings.MODEL_TOKEN_LIMITS.get(
+         self.gpt_model, settings.DEFAULT_MAX_HISTORY
+     )
+ )
self.user_api_key = user_api_key
def _get_vectorstore(self, source):
@@ -85,10 +96,7 @@ class ClassicRAG(BaseRetriever):
tokens_batch = count_tokens(i["prompt"]) + count_tokens(
i["response"]
)
- if (
-     tokens_current_history + tokens_batch
-     < settings.TOKENS_MAX_HISTORY
- ):
+ if tokens_current_history + tokens_batch < self.token_limit:
tokens_current_history += tokens_batch
messages_combine.append(
{"role": "user", "content": i["prompt"]}

View File

@@ -15,6 +15,7 @@ class DuckDuckSearch(BaseRetriever):
chat_history,
prompt,
chunks=2,
+ token_limit=150,
gpt_model="docsgpt",
user_api_key=None,
):
@@ -24,6 +25,16 @@ class DuckDuckSearch(BaseRetriever):
self.prompt = prompt
self.chunks = chunks
self.gpt_model = gpt_model
+ self.token_limit = (
+     token_limit
+     if token_limit
+     < settings.MODEL_TOKEN_LIMITS.get(
+         self.gpt_model, settings.DEFAULT_MAX_HISTORY
+     )
+     else settings.MODEL_TOKEN_LIMITS.get(
+         self.gpt_model, settings.DEFAULT_MAX_HISTORY
+     )
+ )
self.user_api_key = user_api_key
def _parse_lang_string(self, input_string):
@@ -87,10 +98,7 @@ class DuckDuckSearch(BaseRetriever):
tokens_batch = count_tokens(i["prompt"]) + count_tokens(
i["response"]
)
- if (
-     tokens_current_history + tokens_batch
-     < settings.TOKENS_MAX_HISTORY
- ):
+ if tokens_current_history + tokens_batch < self.token_limit:
tokens_current_history += tokens_batch
messages_combine.append(
{"role": "user", "content": i["prompt"]}