Mirror of https://github.com/arc53/DocsGPT.git, synced 2025-12-02 10:03:15 +00:00
Merge pull request #965 from siiddhantt/feature/set-tokens-message-history
feat: dropdown to adjust conversational history limits
@@ -78,7 +78,7 @@ def get_data_from_api_key(api_key):
     if data is None:
         return bad_request(401, "Invalid API key")
     return data


 def get_vectorstore(data):
     if "active_docs" in data:
@@ -95,6 +95,7 @@ def get_vectorstore(data):
     vectorstore = os.path.join("application", vectorstore)
     return vectorstore


 def is_azure_configured():
     return (
         settings.OPENAI_API_BASE
@@ -221,7 +222,10 @@ def stream():
         chunks = int(data["chunks"])
     else:
         chunks = 2

+    if "token_limit" in data:
+        token_limit = data["token_limit"]
+    else:
+        token_limit = settings.DEFAULT_MAX_HISTORY

     # check if active_docs or api_key is set
@@ -255,6 +259,7 @@ def stream():
         chat_history=history,
         prompt=prompt,
         chunks=chunks,
+        token_limit=token_limit,
         gpt_model=gpt_model,
         user_api_key=user_api_key,
     )
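With the stream() changes above, callers can cap how much conversation history is replayed by sending token_limit in the request body. A minimal client sketch follows; the base URL, the /stream path, and the payload fields other than "token_limit" and "question" are assumptions read off this diff, not confirmed API documentation:

import requests

payload = {
    "question": "How do I change the vector store?",
    "history": [],              # prior turns; trimmed server-side to the budget
    "active_docs": "default",   # field name as it appears in the diff
    "token_limit": 1024,        # the new per-request history budget
}
# def stream() suggests a streaming endpoint; host and port are placeholders
resp = requests.post("http://localhost:7091/stream", json=payload, stream=True)
for line in resp.iter_lines():
    print(line.decode())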
@@ -291,6 +296,10 @@ def api_answer():
         chunks = int(data["chunks"])
     else:
         chunks = 2
+    if "token_limit" in data:
+        token_limit = data["token_limit"]
+    else:
+        token_limit = settings.DEFAULT_MAX_HISTORY

     # use try and except to check for exception
     try:
@@ -314,7 +323,7 @@ def api_answer():
         retriever_name = source["active_docs"]

     prompt = get_prompt(prompt_id)

     retriever = RetrieverCreator.create_retriever(
         retriever_name,
         question=question,
@@ -322,6 +331,7 @@ def api_answer():
         chat_history=history,
         prompt=prompt,
         chunks=chunks,
+        token_limit=token_limit,
         gpt_model=gpt_model,
         user_api_key=user_api_key,
     )
@@ -370,7 +380,6 @@ def api_search():
     else:
         source = {}
         user_api_key = None

-
     if (
         source["active_docs"].split("/")[0] == "default"
@@ -379,6 +388,10 @@ def api_search():
         retriever_name = "classic"
     else:
         retriever_name = source["active_docs"]
+    if "token_limit" in data:
+        token_limit = data["token_limit"]
+    else:
+        token_limit = settings.DEFAULT_MAX_HISTORY

     retriever = RetrieverCreator.create_retriever(
         retriever_name,
@@ -387,8 +400,9 @@ def api_search():
         chat_history=[],
         prompt="default",
         chunks=chunks,
+        token_limit=token_limit,
         gpt_model=gpt_model,
         user_api_key=user_api_key,
     )
     docs = retriever.search()
     return docs
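All three routes repeat the same four-line fallback. It is the spelled-out form of a dict lookup with a default; an equivalent one-liner, shown only as a sketch (the PR itself keeps the explicit if/else):

token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)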
@@ -15,7 +15,8 @@ class Settings(BaseSettings):
     CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
     MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
     MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
-    TOKENS_MAX_HISTORY: int = 150
+    DEFAULT_MAX_HISTORY: int = 150
+    MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
     UPLOAD_FOLDER: str = "inputs"
     VECTOR_STORE: str = "faiss"  # "faiss" or "elasticsearch" or "qdrant"
     RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"]  # also brave_search
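In the new settings, DEFAULT_MAX_HISTORY is the budget used when a request sets none, and MODEL_TOKEN_LIMITS caps whatever the caller requests at the model's context size. The conditional expression the retrievers apply below behaves like a min(); a standalone sketch using the values defined here (function and variable names are illustrative, not from the PR):

def clamp_token_limit(requested, gpt_model, model_token_limits, default_max_history):
    # Cap the requested history budget at the model's known limit,
    # falling back to the default for models not in the table.
    cap = model_token_limits.get(gpt_model, default_max_history)
    return min(requested, cap)

limits = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
assert clamp_token_limit(150, "gpt-3.5-turbo", limits, 150) == 150      # kept
assert clamp_token_limit(10_000, "gpt-3.5-turbo", limits, 150) == 4096  # capped
assert clamp_token_limit(10_000, "docsgpt", limits, 150) == 150         # unknown model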
@@ -15,6 +15,7 @@ class BraveRetSearch(BaseRetriever):
         chat_history,
         prompt,
         chunks=2,
+        token_limit=150,
         gpt_model="docsgpt",
         user_api_key=None,
     ):
@@ -24,6 +25,16 @@ class BraveRetSearch(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
+        self.token_limit = (
+            token_limit
+            if token_limit
+            < settings.MODEL_TOKEN_LIMITS.get(
+                self.gpt_model, settings.DEFAULT_MAX_HISTORY
+            )
+            else settings.MODEL_TOKEN_LIMITS.get(
+                self.gpt_model, settings.DEFAULT_MAX_HISTORY
+            )
+        )
         self.user_api_key = user_api_key

     def _get_data(self):
@@ -70,10 +81,7 @@ class BraveRetSearch(BaseRetriever):
             tokens_batch = count_tokens(i["prompt"]) + count_tokens(
                 i["response"]
             )
-            if (
-                tokens_current_history + tokens_batch
-                < settings.TOKENS_MAX_HISTORY
-            ):
+            if tokens_current_history + tokens_batch < self.token_limit:
                 tokens_current_history += tokens_batch
                 messages_combine.append(
                     {"role": "user", "content": i["prompt"]}
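The behavioural change in the history loop is the switch from the global settings.TOKENS_MAX_HISTORY ceiling to the per-instance self.token_limit. A self-contained sketch of that trimming behaviour, assuming the {"prompt": ..., "response": ...} turn shape visible above; the response-side append and the toy count_tokens default are assumptions, as only the user-side append survives in this diff:

def trim_history(chat_history, token_limit, count_tokens=lambda s: len(s.split())):
    # Accumulate turns while they fit in the token budget; skip the rest.
    messages_combine = []
    tokens_current_history = 0
    for i in chat_history:
        tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
        if tokens_current_history + tokens_batch < token_limit:
            tokens_current_history += tokens_batch
            messages_combine.append({"role": "user", "content": i["prompt"]})
            messages_combine.append({"role": "system", "content": i["response"]})
    return messages_combine

The same constructor clamp and the same loop change are applied verbatim to ClassicRAG and DuckDuckSearch below.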
@@ -16,6 +16,7 @@ class ClassicRAG(BaseRetriever):
         chat_history,
         prompt,
         chunks=2,
+        token_limit=150,
         gpt_model="docsgpt",
         user_api_key=None,
     ):
@@ -25,6 +26,16 @@ class ClassicRAG(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
+        self.token_limit = (
+            token_limit
+            if token_limit
+            < settings.MODEL_TOKEN_LIMITS.get(
+                self.gpt_model, settings.DEFAULT_MAX_HISTORY
+            )
+            else settings.MODEL_TOKEN_LIMITS.get(
+                self.gpt_model, settings.DEFAULT_MAX_HISTORY
+            )
+        )
         self.user_api_key = user_api_key

     def _get_vectorstore(self, source):
@@ -85,10 +96,7 @@ class ClassicRAG(BaseRetriever):
             tokens_batch = count_tokens(i["prompt"]) + count_tokens(
                 i["response"]
             )
-            if (
-                tokens_current_history + tokens_batch
-                < settings.TOKENS_MAX_HISTORY
-            ):
+            if tokens_current_history + tokens_batch < self.token_limit:
                 tokens_current_history += tokens_batch
                 messages_combine.append(
                     {"role": "user", "content": i["prompt"]}
@@ -15,6 +15,7 @@ class DuckDuckSearch(BaseRetriever):
         chat_history,
         prompt,
         chunks=2,
+        token_limit=150,
         gpt_model="docsgpt",
         user_api_key=None,
     ):
@@ -24,6 +25,16 @@ class DuckDuckSearch(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
+        self.token_limit = (
+            token_limit
+            if token_limit
+            < settings.MODEL_TOKEN_LIMITS.get(
+                self.gpt_model, settings.DEFAULT_MAX_HISTORY
+            )
+            else settings.MODEL_TOKEN_LIMITS.get(
+                self.gpt_model, settings.DEFAULT_MAX_HISTORY
+            )
+        )
         self.user_api_key = user_api_key

     def _parse_lang_string(self, input_string):
@@ -87,10 +98,7 @@ class DuckDuckSearch(BaseRetriever):
             tokens_batch = count_tokens(i["prompt"]) + count_tokens(
                 i["response"]
             )
-            if (
-                tokens_current_history + tokens_batch
-                < settings.TOKENS_MAX_HISTORY
-            ):
+            if tokens_current_history + tokens_batch < self.token_limit:
                 tokens_current_history += tokens_batch
                 messages_combine.append(
                     {"role": "user", "content": i["prompt"]}