Merge branch 'main' into tool-use

This commit is contained in:
Alex
2024-12-20 17:29:41 +00:00
committed by GitHub
84 changed files with 4317 additions and 5522 deletions

View File

@@ -18,7 +18,7 @@ from application.error import bad_request
from application.extensions import api
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import check_required_fields
from application.utils import check_required_fields, limit_chat_history
logger = logging.getLogger(__name__)
@@ -118,8 +118,31 @@ def is_azure_configured():
)
def save_conversation(conversation_id, question, response, source_log_docs, llm):
if conversation_id is not None and conversation_id != "None":
def save_conversation(conversation_id, question, response, source_log_docs, llm,index=None):
if conversation_id is not None and index is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$set": {
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.sources": source_log_docs,
}
}
)
# Remove the queries that follow the updated one: pushing nothing with $slice keeps only the first index + 1 entries
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$push":{
"queries":{
"$each":[],
"$slice":index+1
}
}
}
)
elif conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
@@ -141,17 +164,17 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm)
"role": "assistant",
"content": "Summarise following conversation in no more than 3 "
"words, respond ONLY with the summary, use the same "
"language as the system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
"language as the system",
},
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system",
"system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
},
]
@@ -186,7 +209,7 @@ def get_prompt(prompt_id):
def complete_stream(
question, retriever, conversation_id, user_api_key, isNoneDoc=False
question, retriever, conversation_id, user_api_key, isNoneDoc=False,index=None
):
try:
@@ -217,7 +240,7 @@ def complete_stream(
)
if user_api_key is None:
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm
conversation_id, question, response_full, source_log_docs, llm,index
)
# build the JSON event of type "id" that carries the conversation id
data = json.dumps({"type": "id", "id": str(conversation_id)})
@@ -282,6 +305,9 @@ class Stream(Resource):
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index":fields.Integer(
required=False, description="The position where query is to be updated"
),
},
)
@@ -290,23 +316,23 @@ class Stream(Resource):
def post(self):
data = request.get_json()
required_fields = ["question"]
if "index" in data:
required_fields = ["question","conversation_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = data.get("history", [])
history = json.loads(history)
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
index=data.get("index",None)
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
@@ -343,7 +369,7 @@ class Stream(Resource):
gpt_model=gpt_model,
user_api_key=user_api_key,
)
return Response(
complete_stream(
question=question,
@@ -351,6 +377,7 @@ class Stream(Resource):
conversation_id=conversation_id,
user_api_key=user_api_key,
isNoneDoc=data.get("isNoneDoc"),
index=index,
),
mimetype="text/event-stream",
)
@@ -428,7 +455,7 @@ class Answer(Resource):
try:
question = data["question"]
history = data.get("history", [])
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))