mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 00:23:17 +00:00
This commit is contained in:
@@ -10,7 +10,7 @@ from application.api import api
|
||||
from application.api.user.base import agents_collection, storage
|
||||
from application.api.user.tasks import store_attachment
|
||||
from application.core.settings import settings
|
||||
from application.tts.google_tts import GoogleTTS
|
||||
from application.tts.tts_creator import TTSCreator
|
||||
from application.utils import safe_filename
|
||||
|
||||
|
||||
@@ -133,7 +133,7 @@ class TextToSpeech(Resource):
|
||||
data = request.get_json()
|
||||
text = data["text"]
|
||||
try:
|
||||
tts_instance = GoogleTTS()
|
||||
tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
|
||||
audio_base64, detected_language = tts_instance.text_to_speech(text)
|
||||
return make_response(
|
||||
jsonify(
|
||||
|
||||
@@ -130,6 +130,7 @@ class Settings(BaseSettings):
|
||||
# Encryption settings
|
||||
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
|
||||
|
||||
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
|
||||
ELEVENLABS_API_KEY: Optional[str] = None
|
||||
|
||||
path = Path(__file__).parent.parent.absolute()
|
||||
|
||||
@@ -10,6 +10,7 @@ ebooklib==0.18
|
||||
escodegen==1.0.11
|
||||
esprima==4.0.1
|
||||
esutils==1.0.1
|
||||
elevenlabs==2.17.0
|
||||
Flask==3.1.1
|
||||
faiss-cpu==1.9.0.post1
|
||||
fastmcp==2.11.0
|
||||
|
||||
@@ -15,10 +15,11 @@ class ElevenlabsTTS(BaseTTS):
|
||||
|
||||
def text_to_speech(self, text):
|
||||
lang = "en"
|
||||
audio = self.client.generate(
|
||||
audio = self.client.text_to_speech.convert(
|
||||
voice_id="nPczCjzI2devNBz1zQrb",
|
||||
model_id="eleven_multilingual_v2",
|
||||
text=text,
|
||||
model="eleven_multilingual_v2",
|
||||
voice="Brian",
|
||||
output_format="mp3_44100_128"
|
||||
)
|
||||
audio_data = BytesIO()
|
||||
for chunk in audio:
|
||||
|
||||
18
application/tts/tts_creator.py
Normal file
18
application/tts/tts_creator.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from application.tts.google_tts import GoogleTTS
|
||||
from application.tts.elevenlabs import ElevenlabsTTS
|
||||
from application.tts.base import BaseTTS
|
||||
|
||||
|
||||
|
||||
class TTSCreator:
|
||||
tts_providers = {
|
||||
"google_tts": GoogleTTS,
|
||||
"elevenlabs": ElevenlabsTTS,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_tts(cls, tts_type, *args, **kwargs)-> BaseTTS:
|
||||
tts_class = cls.tts_providers.get(tts_type.lower())
|
||||
if not tts_class:
|
||||
raise ValueError(f"No tts class found for type {tts_type}")
|
||||
return tts_class(*args, **kwargs)
|
||||
@@ -16,13 +16,26 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch):
|
||||
class DummyClient:
|
||||
def __init__(self, api_key):
|
||||
created["api_key"] = api_key
|
||||
self.generate_calls = []
|
||||
self.convert_calls = []
|
||||
|
||||
def generate(self, *, text, model, voice):
|
||||
self.generate_calls.append({"text": text, "model": model, "voice": voice})
|
||||
class TextToSpeech:
|
||||
def __init__(self, outer):
|
||||
self._outer = outer
|
||||
|
||||
def convert(self, *, voice_id, model_id, text, output_format):
|
||||
self._outer.convert_calls.append(
|
||||
{
|
||||
"voice_id": voice_id,
|
||||
"model_id": model_id,
|
||||
"text": text,
|
||||
"output_format": output_format,
|
||||
}
|
||||
)
|
||||
yield b"chunk-one"
|
||||
yield b"chunk-two"
|
||||
|
||||
self.text_to_speech = TextToSpeech(self)
|
||||
|
||||
client_module = ModuleType("elevenlabs.client")
|
||||
client_module.ElevenLabs = DummyClient
|
||||
package_module = ModuleType("elevenlabs")
|
||||
@@ -35,8 +48,13 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch):
|
||||
audio_base64, lang = tts.text_to_speech("Speak")
|
||||
|
||||
assert created["api_key"] == "api-key"
|
||||
assert tts.client.generate_calls == [
|
||||
{"text": "Speak", "model": "eleven_multilingual_v2", "voice": "Brian"}
|
||||
assert tts.client.convert_calls == [
|
||||
{
|
||||
"voice_id": "nPczCjzI2devNBz1zQrb",
|
||||
"model_id": "eleven_multilingual_v2",
|
||||
"text": "Speak",
|
||||
"output_format": "mp3_44100_128",
|
||||
}
|
||||
]
|
||||
assert lang == "en"
|
||||
assert base64.b64decode(audio_base64.encode()) == b"chunk-onechunk-two"
|
||||
|
||||
61
tests/tts/test_tts_creator.py
Normal file
61
tests/tts/test_tts_creator.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from application.tts.tts_creator import TTSCreator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tts_creator():
|
||||
return TTSCreator()
|
||||
|
||||
|
||||
def test_create_google_tts(tts_creator):
|
||||
# Patch the provider registry so the factory calls our mock class
|
||||
with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}):
|
||||
mock_google_tts = TTSCreator.tts_providers["google_tts"]
|
||||
instance = MagicMock()
|
||||
mock_google_tts.return_value = instance
|
||||
|
||||
result = tts_creator.create_tts("google_tts", "arg1", key="value")
|
||||
|
||||
mock_google_tts.assert_called_once_with("arg1", key="value")
|
||||
assert result == instance
|
||||
|
||||
|
||||
def test_create_elevenlabs_tts(tts_creator):
|
||||
# Patch the provider registry so the factory calls our mock class
|
||||
with patch.dict(TTSCreator.tts_providers, {"elevenlabs": MagicMock()}):
|
||||
mock_elevenlabs_tts = TTSCreator.tts_providers["elevenlabs"]
|
||||
instance = MagicMock()
|
||||
mock_elevenlabs_tts.return_value = instance
|
||||
|
||||
result = tts_creator.create_tts("elevenlabs", "voice", lang="en")
|
||||
|
||||
mock_elevenlabs_tts.assert_called_once_with("voice", lang="en")
|
||||
assert result == instance
|
||||
|
||||
|
||||
def test_invalid_tts_type(tts_creator):
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
tts_creator.create_tts("unknown_tts")
|
||||
assert "No tts class found" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_tts_type_case_insensitivity(tts_creator):
|
||||
# Patch the provider registry to ensure case-insensitive lookup hits our mock
|
||||
with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}):
|
||||
mock_google_tts = TTSCreator.tts_providers["google_tts"]
|
||||
instance = MagicMock()
|
||||
mock_google_tts.return_value = instance
|
||||
|
||||
result = tts_creator.create_tts("GoOgLe_TtS")
|
||||
|
||||
mock_google_tts.assert_called_once_with()
|
||||
assert result == instance
|
||||
|
||||
|
||||
def test_tts_providers_integrity(tts_creator):
|
||||
providers = tts_creator.tts_providers
|
||||
assert "google_tts" in providers
|
||||
assert "elevenlabs" in providers
|
||||
assert callable(providers["google_tts"])
|
||||
assert callable(providers["elevenlabs"])
|
||||
Reference in New Issue
Block a user