mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Merge pull request #1500 from ManishMadan2882/main
Limiting Conversational history
This commit is contained in:
@@ -18,7 +18,7 @@ from application.error import bad_request
|
||||
from application.extensions import api
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import check_required_fields
|
||||
from application.utils import check_required_fields, limit_chat_history
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -324,8 +324,7 @@ class Stream(Resource):
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = str(data.get("history", []))
|
||||
history = str(json.loads(history))
|
||||
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
|
||||
@@ -456,7 +455,7 @@ class Answer(Resource):
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = data.get("history", [])
|
||||
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
chunks = int(data.get("chunks", 2))
|
||||
|
||||
@@ -16,7 +16,7 @@ class Settings(BaseSettings):
|
||||
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
||||
MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
|
||||
DEFAULT_MAX_HISTORY: int = 150
|
||||
MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
|
||||
MODEL_TOKEN_LIMITS: dict = {"gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5}
|
||||
UPLOAD_FOLDER: str = "inputs"
|
||||
PARSE_PDF_AS_IMAGE: bool = False
|
||||
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.utils import num_tokens_from_string
|
||||
from langchain_community.tools import BraveSearch
|
||||
|
||||
|
||||
@@ -73,15 +72,8 @@ class BraveRetSearch(BaseRetriever):
|
||||
yield {"source": doc}
|
||||
|
||||
if len(self.chat_history) > 1:
|
||||
tokens_current_history = 0
|
||||
# count tokens in history
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
|
||||
i["response"]
|
||||
)
|
||||
if tokens_current_history + tokens_batch < self.token_limit:
|
||||
tokens_current_history += tokens_batch
|
||||
messages_combine.append(
|
||||
{"role": "user", "content": i["prompt"]}
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@ from application.core.settings import settings
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
|
||||
class ClassicRAG(BaseRetriever):
|
||||
@@ -73,15 +72,8 @@ class ClassicRAG(BaseRetriever):
|
||||
yield {"source": doc}
|
||||
|
||||
if len(self.chat_history) > 1:
|
||||
tokens_current_history = 0
|
||||
# count tokens in history
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
|
||||
i["response"]
|
||||
)
|
||||
if tokens_current_history + tokens_batch < self.token_limit:
|
||||
tokens_current_history += tokens_batch
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append(
|
||||
{"role": "user", "content": i["prompt"]}
|
||||
)
|
||||
@@ -89,7 +81,7 @@ class ClassicRAG(BaseRetriever):
|
||||
{"role": "system", "content": i["response"]}
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.utils import num_tokens_from_string
|
||||
from langchain_community.tools import DuckDuckGoSearchResults
|
||||
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
|
||||
|
||||
@@ -89,16 +88,9 @@ class DuckDuckSearch(BaseRetriever):
|
||||
for doc in docs:
|
||||
yield {"source": doc}
|
||||
|
||||
if len(self.chat_history) > 1:
|
||||
tokens_current_history = 0
|
||||
# count tokens in history
|
||||
if len(self.chat_history) > 1:
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
|
||||
i["response"]
|
||||
)
|
||||
if tokens_current_history + tokens_batch < self.token_limit:
|
||||
tokens_current_history += tokens_batch
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append(
|
||||
{"role": "user", "content": i["prompt"]}
|
||||
)
|
||||
|
||||
@@ -46,3 +46,40 @@ def check_required_fields(data, required_fields):
|
||||
def get_hash(data):
|
||||
return hashlib.md5(data.encode()).hexdigest()
|
||||
|
||||
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
    """
    Trim chat history to the most recent messages that fit a token budget.

    Messages are walked from newest to oldest. Each message that has both a
    "prompt" and a "response" key is kept while the running token total
    stays strictly below the budget; the walk stops at the first such
    message that would reach or exceed it. Messages missing either key are
    skipped and do not count towards the budget.

    Args:
        history: Sequence of message dicts, oldest first. A falsy value
            (``None`` or an empty list) yields an empty result.
        max_token_limit: Optional caller-supplied budget. It is used only
            when it is truthy and smaller than the limit configured for
            ``gpt_model``; otherwise the configured limit applies.
        gpt_model: Key looked up in ``settings.MODEL_TOKEN_LIMITS``. Models
            not listed there fall back to ``settings.DEFAULT_MAX_HISTORY``.

    Returns:
        A new list with the retained messages in their original order.
    """
    # Nothing to trim: return before touching settings at all.
    if not history:
        return []

    # Kept as a call-time import, as in the original
    # (presumably to avoid a circular import with settings -- TODO confirm).
    from application.core.settings import settings

    # Look the model limit up once instead of in both branches.
    model_limit = settings.MODEL_TOKEN_LIMITS.get(
        gpt_model, settings.DEFAULT_MAX_HISTORY
    )
    budget = model_limit
    if max_token_limit and max_token_limit < model_limit:
        budget = max_token_limit

    tokens_used = 0
    newest_first = []

    for message in reversed(history):
        if "prompt" not in message or "response" not in message:
            continue

        tokens_batch = num_tokens_from_string(
            message["prompt"]
        ) + num_tokens_from_string(message["response"])

        # Stop at the first message that no longer fits the budget.
        if tokens_used + tokens_batch >= budget:
            break

        tokens_used += tokens_batch
        # Append and reverse once at the end; inserting at index 0 on every
        # iteration shifts the whole list each time.
        newest_first.append(message)

    newest_first.reverse()
    return newest_first
|
||||
|
||||
2
frontend/package-lock.json
generated
2
frontend/package-lock.json
generated
@@ -1649,7 +1649,7 @@
|
||||
"version": "18.3.0",
|
||||
"resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.3.0.tgz",
|
||||
"integrity": "sha512-EhwApuTmMBmXuFOikhQLIBUn6uFg81SwLMOAUgodJF14SOBOCMdU04gDoYi0WOJJHD144TL32z4yDqCW3dnkQg==",
|
||||
"devOptional": true,
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@types/react": "*"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user