diff --git a/application/api/user/attachments/routes.py b/application/api/user/attachments/routes.py index a8eae09f..87af6a7e 100644 --- a/application/api/user/attachments/routes.py +++ b/application/api/user/attachments/routes.py @@ -10,7 +10,7 @@ from application.api import api from application.api.user.base import agents_collection, storage from application.api.user.tasks import store_attachment from application.core.settings import settings -from application.tts.google_tts import GoogleTTS +from application.tts.tts_creator import TTSCreator from application.utils import safe_filename @@ -133,7 +133,7 @@ class TextToSpeech(Resource): data = request.get_json() text = data["text"] try: - tts_instance = GoogleTTS() + tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER) audio_base64, detected_language = tts_instance.text_to_speech(text) return make_response( jsonify( diff --git a/application/core/settings.py b/application/core/settings.py index 0be38275..3871d09b 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -130,6 +130,7 @@ class Settings(BaseSettings): # Encryption settings ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key" + TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs ELEVENLABS_API_KEY: Optional[str] = None path = Path(__file__).parent.parent.absolute() diff --git a/application/requirements.txt b/application/requirements.txt index 3882bd6d..08d259b1 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -10,6 +10,7 @@ ebooklib==0.18 escodegen==1.0.11 esprima==4.0.1 esutils==1.0.1 +elevenlabs==2.17.0 Flask==3.1.1 faiss-cpu==1.9.0.post1 fastmcp==2.11.0 diff --git a/application/tts/elevenlabs.py b/application/tts/elevenlabs.py index 0d82021e..c1927c6f 100644 --- a/application/tts/elevenlabs.py +++ b/application/tts/elevenlabs.py @@ -15,10 +15,11 @@ class ElevenlabsTTS(BaseTTS): def text_to_speech(self, text): lang = "en" - audio = self.client.generate( + audio = self.client.text_to_speech.convert( + voice_id="nPczCjzI2devNBz1zQrb", + model_id="eleven_multilingual_v2", text=text, - model="eleven_multilingual_v2", - voice="Brian", + output_format="mp3_44100_128" ) audio_data = BytesIO() for chunk in audio: diff --git a/application/tts/tts_creator.py b/application/tts/tts_creator.py new file mode 100644 index 00000000..28d9f51b --- /dev/null +++ b/application/tts/tts_creator.py @@ -0,0 +1,18 @@ +from application.tts.google_tts import GoogleTTS +from application.tts.elevenlabs import ElevenlabsTTS +from application.tts.base import BaseTTS + + + +class TTSCreator: + tts_providers = { + "google_tts": GoogleTTS, + "elevenlabs": ElevenlabsTTS, + } + + @classmethod + def create_tts(cls, tts_type, *args, **kwargs)-> BaseTTS: + tts_class = cls.tts_providers.get(tts_type.lower()) + if not tts_class: + raise ValueError(f"No tts class found for type {tts_type}") + return tts_class(*args, **kwargs) \ No newline at end of file diff --git a/tests/tts/test_elevenlabs_tts.py b/tests/tts/test_elevenlabs_tts.py index bea61b80..786b0f41 100644 --- a/tests/tts/test_elevenlabs_tts.py +++ b/tests/tts/test_elevenlabs_tts.py @@ -16,12 +16,25 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch): class DummyClient: def __init__(self, api_key): created["api_key"] = api_key - self.generate_calls = [] + self.convert_calls = [] - def generate(self, *, text, model, voice): - self.generate_calls.append({"text": text, "model": model, "voice": voice}) - yield b"chunk-one" - yield b"chunk-two" + class TextToSpeech: + def __init__(self, outer): + self._outer = outer + + def convert(self, *, voice_id, model_id, text, output_format): + self._outer.convert_calls.append( + { + "voice_id": voice_id, + "model_id": model_id, + "text": text, + "output_format": output_format, + } + ) + yield b"chunk-one" + yield b"chunk-two" + + self.text_to_speech = TextToSpeech(self) client_module = ModuleType("elevenlabs.client") client_module.ElevenLabs = DummyClient @@ -35,8 +48,13 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch): audio_base64, lang = tts.text_to_speech("Speak") assert created["api_key"] == "api-key" - assert tts.client.generate_calls == [ - {"text": "Speak", "model": "eleven_multilingual_v2", "voice": "Brian"} + assert tts.client.convert_calls == [ + { + "voice_id": "nPczCjzI2devNBz1zQrb", + "model_id": "eleven_multilingual_v2", + "text": "Speak", + "output_format": "mp3_44100_128", + } ] assert lang == "en" assert base64.b64decode(audio_base64.encode()) == b"chunk-onechunk-two" diff --git a/tests/tts/test_tts_creator.py b/tests/tts/test_tts_creator.py new file mode 100644 index 00000000..526008fc --- /dev/null +++ b/tests/tts/test_tts_creator.py @@ -0,0 +1,61 @@ +import pytest +from unittest.mock import patch, MagicMock +from application.tts.tts_creator import TTSCreator + + +@pytest.fixture +def tts_creator(): + return TTSCreator() + + +def test_create_google_tts(tts_creator): + # Patch the provider registry so the factory calls our mock class + with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}): + mock_google_tts = TTSCreator.tts_providers["google_tts"] + instance = MagicMock() + mock_google_tts.return_value = instance + + result = tts_creator.create_tts("google_tts", "arg1", key="value") + + mock_google_tts.assert_called_once_with("arg1", key="value") + assert result == instance + + +def test_create_elevenlabs_tts(tts_creator): + # Patch the provider registry so the factory calls our mock class + with patch.dict(TTSCreator.tts_providers, {"elevenlabs": MagicMock()}): + mock_elevenlabs_tts = TTSCreator.tts_providers["elevenlabs"] + instance = MagicMock() + mock_elevenlabs_tts.return_value = instance + + result = tts_creator.create_tts("elevenlabs", "voice", lang="en") + + mock_elevenlabs_tts.assert_called_once_with("voice", lang="en") + assert result == instance + + +def test_invalid_tts_type(tts_creator): + with pytest.raises(ValueError) as excinfo: + tts_creator.create_tts("unknown_tts") + assert "No tts class found" in str(excinfo.value) + + +def test_tts_type_case_insensitivity(tts_creator): + # Patch the provider registry to ensure case-insensitive lookup hits our mock + with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}): + mock_google_tts = TTSCreator.tts_providers["google_tts"] + instance = MagicMock() + mock_google_tts.return_value = instance + + result = tts_creator.create_tts("GoOgLe_TtS") + + mock_google_tts.assert_called_once_with() + assert result == instance + + +def test_tts_providers_integrity(tts_creator): + providers = tts_creator.tts_providers + assert "google_tts" in providers + assert "elevenlabs" in providers + assert callable(providers["google_tts"]) + assert callable(providers["elevenlabs"]) \ No newline at end of file