mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 00:23:17 +00:00
This commit is contained in:
@@ -10,7 +10,7 @@ from application.api import api
|
|||||||
from application.api.user.base import agents_collection, storage
|
from application.api.user.base import agents_collection, storage
|
||||||
from application.api.user.tasks import store_attachment
|
from application.api.user.tasks import store_attachment
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.tts.google_tts import GoogleTTS
|
from application.tts.tts_creator import TTSCreator
|
||||||
from application.utils import safe_filename
|
from application.utils import safe_filename
|
||||||
|
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ class TextToSpeech(Resource):
|
|||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
text = data["text"]
|
text = data["text"]
|
||||||
try:
|
try:
|
||||||
tts_instance = GoogleTTS()
|
tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
|
||||||
audio_base64, detected_language = tts_instance.text_to_speech(text)
|
audio_base64, detected_language = tts_instance.text_to_speech(text)
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
|
|||||||
@@ -130,6 +130,7 @@ class Settings(BaseSettings):
|
|||||||
# Encryption settings
|
# Encryption settings
|
||||||
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
|
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
|
||||||
|
|
||||||
|
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
|
||||||
ELEVENLABS_API_KEY: Optional[str] = None
|
ELEVENLABS_API_KEY: Optional[str] = None
|
||||||
|
|
||||||
path = Path(__file__).parent.parent.absolute()
|
path = Path(__file__).parent.parent.absolute()
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ ebooklib==0.18
|
|||||||
escodegen==1.0.11
|
escodegen==1.0.11
|
||||||
esprima==4.0.1
|
esprima==4.0.1
|
||||||
esutils==1.0.1
|
esutils==1.0.1
|
||||||
|
elevenlabs==2.17.0
|
||||||
Flask==3.1.1
|
Flask==3.1.1
|
||||||
faiss-cpu==1.9.0.post1
|
faiss-cpu==1.9.0.post1
|
||||||
fastmcp==2.11.0
|
fastmcp==2.11.0
|
||||||
|
|||||||
@@ -15,10 +15,11 @@ class ElevenlabsTTS(BaseTTS):
|
|||||||
|
|
||||||
def text_to_speech(self, text):
|
def text_to_speech(self, text):
|
||||||
lang = "en"
|
lang = "en"
|
||||||
audio = self.client.generate(
|
audio = self.client.text_to_speech.convert(
|
||||||
|
voice_id="nPczCjzI2devNBz1zQrb",
|
||||||
|
model_id="eleven_multilingual_v2",
|
||||||
text=text,
|
text=text,
|
||||||
model="eleven_multilingual_v2",
|
output_format="mp3_44100_128"
|
||||||
voice="Brian",
|
|
||||||
)
|
)
|
||||||
audio_data = BytesIO()
|
audio_data = BytesIO()
|
||||||
for chunk in audio:
|
for chunk in audio:
|
||||||
|
|||||||
18
application/tts/tts_creator.py
Normal file
18
application/tts/tts_creator.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from application.tts.google_tts import GoogleTTS
|
||||||
|
from application.tts.elevenlabs import ElevenlabsTTS
|
||||||
|
from application.tts.base import BaseTTS
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TTSCreator:
|
||||||
|
tts_providers = {
|
||||||
|
"google_tts": GoogleTTS,
|
||||||
|
"elevenlabs": ElevenlabsTTS,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_tts(cls, tts_type, *args, **kwargs)-> BaseTTS:
|
||||||
|
tts_class = cls.tts_providers.get(tts_type.lower())
|
||||||
|
if not tts_class:
|
||||||
|
raise ValueError(f"No tts class found for type {tts_type}")
|
||||||
|
return tts_class(*args, **kwargs)
|
||||||
@@ -16,12 +16,25 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch):
|
|||||||
class DummyClient:
|
class DummyClient:
|
||||||
def __init__(self, api_key):
|
def __init__(self, api_key):
|
||||||
created["api_key"] = api_key
|
created["api_key"] = api_key
|
||||||
self.generate_calls = []
|
self.convert_calls = []
|
||||||
|
|
||||||
def generate(self, *, text, model, voice):
|
class TextToSpeech:
|
||||||
self.generate_calls.append({"text": text, "model": model, "voice": voice})
|
def __init__(self, outer):
|
||||||
yield b"chunk-one"
|
self._outer = outer
|
||||||
yield b"chunk-two"
|
|
||||||
|
def convert(self, *, voice_id, model_id, text, output_format):
|
||||||
|
self._outer.convert_calls.append(
|
||||||
|
{
|
||||||
|
"voice_id": voice_id,
|
||||||
|
"model_id": model_id,
|
||||||
|
"text": text,
|
||||||
|
"output_format": output_format,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
yield b"chunk-one"
|
||||||
|
yield b"chunk-two"
|
||||||
|
|
||||||
|
self.text_to_speech = TextToSpeech(self)
|
||||||
|
|
||||||
client_module = ModuleType("elevenlabs.client")
|
client_module = ModuleType("elevenlabs.client")
|
||||||
client_module.ElevenLabs = DummyClient
|
client_module.ElevenLabs = DummyClient
|
||||||
@@ -35,8 +48,13 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch):
|
|||||||
audio_base64, lang = tts.text_to_speech("Speak")
|
audio_base64, lang = tts.text_to_speech("Speak")
|
||||||
|
|
||||||
assert created["api_key"] == "api-key"
|
assert created["api_key"] == "api-key"
|
||||||
assert tts.client.generate_calls == [
|
assert tts.client.convert_calls == [
|
||||||
{"text": "Speak", "model": "eleven_multilingual_v2", "voice": "Brian"}
|
{
|
||||||
|
"voice_id": "nPczCjzI2devNBz1zQrb",
|
||||||
|
"model_id": "eleven_multilingual_v2",
|
||||||
|
"text": "Speak",
|
||||||
|
"output_format": "mp3_44100_128",
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert lang == "en"
|
assert lang == "en"
|
||||||
assert base64.b64decode(audio_base64.encode()) == b"chunk-onechunk-two"
|
assert base64.b64decode(audio_base64.encode()) == b"chunk-onechunk-two"
|
||||||
|
|||||||
61
tests/tts/test_tts_creator.py
Normal file
61
tests/tts/test_tts_creator.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from application.tts.tts_creator import TTSCreator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tts_creator():
|
||||||
|
return TTSCreator()
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_google_tts(tts_creator):
|
||||||
|
# Patch the provider registry so the factory calls our mock class
|
||||||
|
with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}):
|
||||||
|
mock_google_tts = TTSCreator.tts_providers["google_tts"]
|
||||||
|
instance = MagicMock()
|
||||||
|
mock_google_tts.return_value = instance
|
||||||
|
|
||||||
|
result = tts_creator.create_tts("google_tts", "arg1", key="value")
|
||||||
|
|
||||||
|
mock_google_tts.assert_called_once_with("arg1", key="value")
|
||||||
|
assert result == instance
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_elevenlabs_tts(tts_creator):
|
||||||
|
# Patch the provider registry so the factory calls our mock class
|
||||||
|
with patch.dict(TTSCreator.tts_providers, {"elevenlabs": MagicMock()}):
|
||||||
|
mock_elevenlabs_tts = TTSCreator.tts_providers["elevenlabs"]
|
||||||
|
instance = MagicMock()
|
||||||
|
mock_elevenlabs_tts.return_value = instance
|
||||||
|
|
||||||
|
result = tts_creator.create_tts("elevenlabs", "voice", lang="en")
|
||||||
|
|
||||||
|
mock_elevenlabs_tts.assert_called_once_with("voice", lang="en")
|
||||||
|
assert result == instance
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_tts_type(tts_creator):
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
tts_creator.create_tts("unknown_tts")
|
||||||
|
assert "No tts class found" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tts_type_case_insensitivity(tts_creator):
|
||||||
|
# Patch the provider registry to ensure case-insensitive lookup hits our mock
|
||||||
|
with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}):
|
||||||
|
mock_google_tts = TTSCreator.tts_providers["google_tts"]
|
||||||
|
instance = MagicMock()
|
||||||
|
mock_google_tts.return_value = instance
|
||||||
|
|
||||||
|
result = tts_creator.create_tts("GoOgLe_TtS")
|
||||||
|
|
||||||
|
mock_google_tts.assert_called_once_with()
|
||||||
|
assert result == instance
|
||||||
|
|
||||||
|
|
||||||
|
def test_tts_providers_integrity(tts_creator):
|
||||||
|
providers = tts_creator.tts_providers
|
||||||
|
assert "google_tts" in providers
|
||||||
|
assert "elevenlabs" in providers
|
||||||
|
assert callable(providers["google_tts"])
|
||||||
|
assert callable(providers["elevenlabs"])
|
||||||
Reference in New Issue
Block a user