add configurable provider in settings.py and update ElevenLabs Api (#2065) (#2074)

This commit is contained in:
Nihar
2025-10-22 21:37:21 +05:30
committed by GitHub
parent c4e8daf50e
commit f448e4a615
7 changed files with 112 additions and 12 deletions

View File

@@ -10,7 +10,7 @@ from application.api import api
from application.api.user.base import agents_collection, storage from application.api.user.base import agents_collection, storage
from application.api.user.tasks import store_attachment from application.api.user.tasks import store_attachment
from application.core.settings import settings from application.core.settings import settings
from application.tts.google_tts import GoogleTTS from application.tts.tts_creator import TTSCreator
from application.utils import safe_filename from application.utils import safe_filename
@@ -133,7 +133,7 @@ class TextToSpeech(Resource):
data = request.get_json() data = request.get_json()
text = data["text"] text = data["text"]
try: try:
tts_instance = GoogleTTS() tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
audio_base64, detected_language = tts_instance.text_to_speech(text) audio_base64, detected_language = tts_instance.text_to_speech(text)
return make_response( return make_response(
jsonify( jsonify(

View File

@@ -130,6 +130,7 @@ class Settings(BaseSettings):
# Encryption settings # Encryption settings
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key" ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
ELEVENLABS_API_KEY: Optional[str] = None ELEVENLABS_API_KEY: Optional[str] = None
path = Path(__file__).parent.parent.absolute() path = Path(__file__).parent.parent.absolute()

View File

@@ -10,6 +10,7 @@ ebooklib==0.18
escodegen==1.0.11 escodegen==1.0.11
esprima==4.0.1 esprima==4.0.1
esutils==1.0.1 esutils==1.0.1
elevenlabs==2.17.0
Flask==3.1.1 Flask==3.1.1
faiss-cpu==1.9.0.post1 faiss-cpu==1.9.0.post1
fastmcp==2.11.0 fastmcp==2.11.0

View File

@@ -15,10 +15,11 @@ class ElevenlabsTTS(BaseTTS):
def text_to_speech(self, text): def text_to_speech(self, text):
lang = "en" lang = "en"
audio = self.client.generate( audio = self.client.text_to_speech.convert(
voice_id="nPczCjzI2devNBz1zQrb",
model_id="eleven_multilingual_v2",
text=text, text=text,
model="eleven_multilingual_v2", output_format="mp3_44100_128"
voice="Brian",
) )
audio_data = BytesIO() audio_data = BytesIO()
for chunk in audio: for chunk in audio:

View File

@@ -0,0 +1,18 @@
from application.tts.google_tts import GoogleTTS
from application.tts.elevenlabs import ElevenlabsTTS
from application.tts.base import BaseTTS
class TTSCreator:
tts_providers = {
"google_tts": GoogleTTS,
"elevenlabs": ElevenlabsTTS,
}
@classmethod
def create_tts(cls, tts_type, *args, **kwargs)-> BaseTTS:
tts_class = cls.tts_providers.get(tts_type.lower())
if not tts_class:
raise ValueError(f"No tts class found for type {tts_type}")
return tts_class(*args, **kwargs)

View File

@@ -16,12 +16,25 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch):
class DummyClient: class DummyClient:
def __init__(self, api_key): def __init__(self, api_key):
created["api_key"] = api_key created["api_key"] = api_key
self.generate_calls = [] self.convert_calls = []
def generate(self, *, text, model, voice): class TextToSpeech:
self.generate_calls.append({"text": text, "model": model, "voice": voice}) def __init__(self, outer):
yield b"chunk-one" self._outer = outer
yield b"chunk-two"
def convert(self, *, voice_id, model_id, text, output_format):
self._outer.convert_calls.append(
{
"voice_id": voice_id,
"model_id": model_id,
"text": text,
"output_format": output_format,
}
)
yield b"chunk-one"
yield b"chunk-two"
self.text_to_speech = TextToSpeech(self)
client_module = ModuleType("elevenlabs.client") client_module = ModuleType("elevenlabs.client")
client_module.ElevenLabs = DummyClient client_module.ElevenLabs = DummyClient
@@ -35,8 +48,13 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch):
audio_base64, lang = tts.text_to_speech("Speak") audio_base64, lang = tts.text_to_speech("Speak")
assert created["api_key"] == "api-key" assert created["api_key"] == "api-key"
assert tts.client.generate_calls == [ assert tts.client.convert_calls == [
{"text": "Speak", "model": "eleven_multilingual_v2", "voice": "Brian"} {
"voice_id": "nPczCjzI2devNBz1zQrb",
"model_id": "eleven_multilingual_v2",
"text": "Speak",
"output_format": "mp3_44100_128",
}
] ]
assert lang == "en" assert lang == "en"
assert base64.b64decode(audio_base64.encode()) == b"chunk-onechunk-two" assert base64.b64decode(audio_base64.encode()) == b"chunk-onechunk-two"

View File

@@ -0,0 +1,61 @@
import pytest
from unittest.mock import patch, MagicMock
from application.tts.tts_creator import TTSCreator
@pytest.fixture
def tts_creator():
return TTSCreator()
def test_create_google_tts(tts_creator):
# Patch the provider registry so the factory calls our mock class
with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}):
mock_google_tts = TTSCreator.tts_providers["google_tts"]
instance = MagicMock()
mock_google_tts.return_value = instance
result = tts_creator.create_tts("google_tts", "arg1", key="value")
mock_google_tts.assert_called_once_with("arg1", key="value")
assert result == instance
def test_create_elevenlabs_tts(tts_creator):
# Patch the provider registry so the factory calls our mock class
with patch.dict(TTSCreator.tts_providers, {"elevenlabs": MagicMock()}):
mock_elevenlabs_tts = TTSCreator.tts_providers["elevenlabs"]
instance = MagicMock()
mock_elevenlabs_tts.return_value = instance
result = tts_creator.create_tts("elevenlabs", "voice", lang="en")
mock_elevenlabs_tts.assert_called_once_with("voice", lang="en")
assert result == instance
def test_invalid_tts_type(tts_creator):
with pytest.raises(ValueError) as excinfo:
tts_creator.create_tts("unknown_tts")
assert "No tts class found" in str(excinfo.value)
def test_tts_type_case_insensitivity(tts_creator):
# Patch the provider registry to ensure case-insensitive lookup hits our mock
with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}):
mock_google_tts = TTSCreator.tts_providers["google_tts"]
instance = MagicMock()
mock_google_tts.return_value = instance
result = tts_creator.create_tts("GoOgLe_TtS")
mock_google_tts.assert_called_once_with()
assert result == instance
def test_tts_providers_integrity(tts_creator):
providers = tts_creator.tts_providers
assert "google_tts" in providers
assert "elevenlabs" in providers
assert callable(providers["google_tts"])
assert callable(providers["elevenlabs"])