diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 17eb5cc3..e05de123 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -269,9 +269,6 @@ class Stream(Resource): "prompt_id": fields.String( required=False, default="default", description="Prompt ID" ), - "selectedDocs": fields.String( - required=False, description="Selected documents" - ), "chunks": fields.Integer( required=False, default=2, description="Number of chunks" ), @@ -303,10 +300,9 @@ class Stream(Resource): history = json.loads(history) conversation_id = data.get("conversation_id") prompt_id = data.get("prompt_id", "default") - if "selectedDocs" in data and data["selectedDocs"] is None: - chunks = 0 - else: - chunks = int(data.get("chunks", 2)) + + + chunks = int(data.get("chunks", 2)) token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY) retriever_name = data.get("retriever", "classic") @@ -333,7 +329,8 @@ class Stream(Resource): ) prompt = get_prompt(prompt_id) - + if "isNoneDoc" in data and data["isNoneDoc"] is True: + chunks = 0 retriever = RetrieverCreator.create_retriever( retriever_name, question=question, diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 3f1a7218..794c69d4 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -17,6 +17,7 @@ from application.core.settings import settings from application.extensions import api from application.utils import check_required_fields from application.vectorstore.vector_creator import VectorCreator +from application.tts.google_tts import GoogleTTS mongo = MongoClient(settings.MONGO_URI) db = mongo["docsgpt"] @@ -1663,3 +1664,27 @@ class ManageSync(Resource): return make_response(jsonify({"success": False, "error": str(err)}), 400) return make_response(jsonify({"success": True}), 200) + + +@user_ns.route("/api/tts") +class TextToSpeech(Resource): + tts_model = api.model( + "TextToSpeechModel", + { + "text": fields.String(required=True, description="Text to be synthesized as audio"), + }, + ) + + @api.expect(tts_model) + @api.doc(description="Synthesize audio speech from text") + def post(self): + data = request.get_json() + text = data["text"] + try: + tts_instance = GoogleTTS(text) + audio_base64, detected_language = tts_instance.text_to_speech() + return make_response(jsonify({"success": True,'audio_base64': audio_base64,'lang':detected_language}), 200) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + diff --git a/application/requirements.txt b/application/requirements.txt index 6ea1d1ba..aad629f1 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -85,3 +85,4 @@ vine==5.1.0 wcwidth==0.2.13 werkzeug==3.0.4 yarl==1.11.1 +gTTS==2.3.2 \ No newline at end of file diff --git a/application/tts/base.py b/application/tts/base.py new file mode 100644 index 00000000..143bed73 --- /dev/null +++ b/application/tts/base.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod + + +class BaseTTS(ABC): + def __init__(self): + pass + + @abstractmethod + def text_to_speech(self, *args, **kwargs): + pass \ No newline at end of file diff --git a/application/tts/google_tts.py b/application/tts/google_tts.py new file mode 100644 index 00000000..310309dc --- /dev/null +++ b/application/tts/google_tts.py @@ -0,0 +1,19 @@ +import io +import base64 +from gtts import gTTS +from application.tts.base import BaseTTS + + +class GoogleTTS(BaseTTS): + def __init__(self, text): + self.text = text + + + def text_to_speech(self): + lang = "en" + audio_fp = io.BytesIO() + tts = gTTS(text=self.text, lang=lang, slow=False) + tts.write_to_fp(audio_fp) + audio_fp.seek(0) + audio_base64 = base64.b64encode(audio_fp.read()).decode("utf-8") + return audio_base64, lang