diff --git a/whisperlivekit/__init__.py b/whisperlivekit/__init__.py index 50fb50c..5e7e884 100644 --- a/whisperlivekit/__init__.py +++ b/whisperlivekit/__init__.py @@ -1,7 +1,7 @@ from .audio_processor import AudioProcessor from .core import TranscriptionEngine from .parse_args import parse_args -from .web.web_interface import get_web_interface_html, get_inline_ui_html +from .web.web_interface import get_inline_ui_html, get_web_interface_html __all__ = [ "TranscriptionEngine", diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index ee53c54..64b66ba 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -1,14 +1,20 @@ import asyncio -import numpy as np -from time import time import logging import traceback -from typing import Optional, Union, List, Any, AsyncGenerator -from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker -from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory -from whisperlivekit.silero_vad_iterator import FixedVADIterator +from time import time +from typing import Any, AsyncGenerator, List, Optional, Union + +import numpy as np + +from whisperlivekit.core import (TranscriptionEngine, + online_diarization_factory, online_factory, + online_translation_factory) from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState +from whisperlivekit.silero_vad_iterator import FixedVADIterator +from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData, + Line, Silence, State, Transcript) from whisperlivekit.tokens_alignment import TokensAlignment + logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) diff --git a/whisperlivekit/basic_server.py b/whisperlivekit/basic_server.py index c1b682e..1694e0c 100644 --- a/whisperlivekit/basic_server.py +++ b/whisperlivekit/basic_server.py @@ -1,10 +1,13 @@ -from contextlib import asynccontextmanager -from fastapi import FastAPI, WebSocket, WebSocketDisconnect -from fastapi.responses import HTMLResponse -from fastapi.middleware.cors import CORSMiddleware -from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args import asyncio import logging +from contextlib import asynccontextmanager + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse + +from whisperlivekit import (AudioProcessor, TranscriptionEngine, + get_inline_ui_html, parse_args) logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.getLogger().setLevel(logging.WARNING) diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index f80b815..33ced0d 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -1,9 +1,11 @@ +import logging +import sys +from argparse import Namespace + +from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor from whisperlivekit.local_agreement.whisper_online import backend_factory from whisperlivekit.simul_whisper import SimulStreamingASR -from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor -from argparse import Namespace -import sys -import logging + def update_with_kwargs(_dict, kwargs): _dict.update({ @@ -80,6 +82,7 @@ class TranscriptionEngine: if self.args.vac: from whisperlivekit.silero_vad_iterator import load_silero_vad + # Use ONNX if specified, otherwise use JIT (default) use_onnx = kwargs.get('vac_onnx', False) self.vac_model = load_silero_vad(onnx=use_onnx) @@ -135,7 +138,8 @@ class TranscriptionEngine: if self.args.diarization: if self.args.diarization_backend == "diart": - from whisperlivekit.diarization.diart_backend import DiartDiarization + from whisperlivekit.diarization.diart_backend import \ + DiartDiarization diart_params = { "segmentation_model": "pyannote/segmentation-3.0", "embedding_model": "pyannote/embedding", @@ -146,7 +150,8 @@ class TranscriptionEngine: **diart_params ) elif self.args.diarization_backend == "sortformer": - from whisperlivekit.diarization.sortformer_backend import SortformerDiarization + from whisperlivekit.diarization.sortformer_backend import \ + SortformerDiarization self.diarization_model = SortformerDiarization() self.translation_model = None @@ -182,7 +187,8 @@ def online_diarization_factory(args, diarization_backend): # Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended if args.diarization_backend == "sortformer": - from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline + from whisperlivekit.diarization.sortformer_backend import \ + SortformerDiarizationOnline online = SortformerDiarizationOnline(shared_model=diarization_backend) return online diff --git a/whisperlivekit/diarization/diart_backend.py b/whisperlivekit/diarization/diart_backend.py index 0525973..df5d2c7 100644 --- a/whisperlivekit/diarization/diart_backend.py +++ b/whisperlivekit/diarization/diart_backend.py @@ -1,20 +1,20 @@ import asyncio +import logging import re import threading -import numpy as np -import logging import time -from queue import SimpleQueue, Empty +from queue import Empty, SimpleQueue +from typing import Any, List, Tuple +import diart.models as m +import numpy as np from diart import SpeakerDiarization, SpeakerDiarizationConfig from diart.inference import StreamingInference -from diart.sources import AudioSource -from whisperlivekit.timed_objects import SpeakerSegment -from diart.sources import MicrophoneAudioSource -from rx.core import Observer -from typing import Tuple, Any, List +from diart.sources import AudioSource, MicrophoneAudioSource from pyannote.core import Annotation -import diart.models as m +from rx.core import Observer + +from whisperlivekit.timed_objects import SpeakerSegment logger = logging.getLogger(__name__) diff --git a/whisperlivekit/diarization/sortformer_backend.py b/whisperlivekit/diarization/sortformer_backend.py index 4fb4627..474b6dc 100644 --- a/whisperlivekit/diarization/sortformer_backend.py +++ b/whisperlivekit/diarization/sortformer_backend.py @@ -1,11 +1,12 @@ -import numpy as np -import torch import logging import threading import time import wave +from queue import Empty, SimpleQueue from typing import List, Optional -from queue import SimpleQueue, Empty + +import numpy as np +import torch from whisperlivekit.timed_objects import SpeakerSegment @@ -295,6 +296,7 @@ def extract_number(s: str) -> int: if __name__ == '__main__': import asyncio + import librosa async def main(): diff --git a/whisperlivekit/ffmpeg_manager.py b/whisperlivekit/ffmpeg_manager.py index cc0275b..406cf42 100644 --- a/whisperlivekit/ffmpeg_manager.py +++ b/whisperlivekit/ffmpeg_manager.py @@ -1,8 +1,8 @@ import asyncio +import contextlib import logging from enum import Enum -from typing import Optional, Callable -import contextlib +from typing import Callable, Optional logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) diff --git a/whisperlivekit/local_agreement/backends.py b/whisperlivekit/local_agreement/backends.py index a4c67f9..e4cb79e 100644 --- a/whisperlivekit/local_agreement/backends.py +++ b/whisperlivekit/local_agreement/backends.py @@ -1,13 +1,16 @@ -import sys -import logging import io -import soundfile as sf +import logging import math +import sys from typing import List + import numpy as np +import soundfile as sf + +from whisperlivekit.model_paths import model_path_and_type, resolve_model_path from whisperlivekit.timed_objects import ASRToken -from whisperlivekit.model_paths import resolve_model_path, model_path_and_type from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe + logger = logging.getLogger(__name__) class ASRBase: sep = " " # join transcribe words with this character (" " for whisper_timestamped, @@ -165,8 +168,8 @@ class MLXWhisper(ASRBase): sep = "" def load_model(self, model_size=None, cache_dir=None, model_dir=None): - from mlx_whisper.transcribe import ModelHolder, transcribe import mlx.core as mx + from mlx_whisper.transcribe import ModelHolder, transcribe if model_dir is not None: resolved_path = resolve_model_path(model_dir) diff --git a/whisperlivekit/local_agreement/online_asr.py b/whisperlivekit/local_agreement/online_asr.py index 26403cd..e5b8632 100644 --- a/whisperlivekit/local_agreement/online_asr.py +++ b/whisperlivekit/local_agreement/online_asr.py @@ -1,7 +1,9 @@ -import sys -import numpy as np import logging -from typing import List, Tuple, Optional +import sys +from typing import List, Optional, Tuple + +import numpy as np + from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript logger = logging.getLogger(__name__) diff --git a/whisperlivekit/local_agreement/whisper_online.py b/whisperlivekit/local_agreement/whisper_online.py index 87b43ff..433fa16 100644 --- a/whisperlivekit/local_agreement/whisper_online.py +++ b/whisperlivekit/local_agreement/whisper_online.py @@ -1,18 +1,19 @@ #!/usr/bin/env python3 -import sys -import numpy as np -import librosa -from functools import lru_cache -import time import logging import platform -from .backends import FasterWhisperASR, MLXWhisper, WhisperASR, OpenaiApiASR +import sys +import time +from functools import lru_cache + +import librosa +import numpy as np + +from whisperlivekit.backend_support import (faster_backend_available, + mlx_backend_available) +from whisperlivekit.model_paths import model_path_and_type, resolve_model_path from whisperlivekit.warmup import warmup_asr -from whisperlivekit.model_paths import resolve_model_path, model_path_and_type -from whisperlivekit.backend_support import ( - mlx_backend_available, - faster_backend_available, -) + +from .backends import FasterWhisperASR, MLXWhisper, OpenaiApiASR, WhisperASR logger = logging.getLogger(__name__) diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index b24c029..d342645 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -1,6 +1,7 @@ from argparse import ArgumentParser + def parse_args(): parser = ArgumentParser(description="Whisper FastAPI Online Server") parser.add_argument( diff --git a/whisperlivekit/silero_vad_iterator.py b/whisperlivekit/silero_vad_iterator.py index d53d056..45ca63c 100644 --- a/whisperlivekit/silero_vad_iterator.py +++ b/whisperlivekit/silero_vad_iterator.py @@ -1,8 +1,9 @@ -import torch -import numpy as np import warnings from pathlib import Path +import numpy as np +import torch + """ Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad """ diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index f3714f8..8830db7 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -1,31 +1,30 @@ -import sys -import numpy as np +import gc import logging -from typing import List, Tuple, Optional +import os import platform -from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker +import sys +from pathlib import Path +from typing import List, Optional, Tuple + +import numpy as np +import torch + +from whisperlivekit.backend_support import (faster_backend_available, + mlx_backend_available) +from whisperlivekit.model_paths import model_path_and_type, resolve_model_path +from whisperlivekit.simul_whisper.config import AlignAttConfig +from whisperlivekit.simul_whisper.simul_whisper import AlignAtt +from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript from whisperlivekit.warmup import load_file from whisperlivekit.whisper import load_model, tokenizer from whisperlivekit.whisper.audio import TOKENS_PER_SECOND -import os -import gc -from pathlib import Path -from whisperlivekit.model_paths import model_path_and_type, resolve_model_path -from whisperlivekit.backend_support import ( - mlx_backend_available, - faster_backend_available, -) - -import torch -from whisperlivekit.simul_whisper.config import AlignAttConfig -from whisperlivekit.simul_whisper.simul_whisper import AlignAtt logger = logging.getLogger(__name__) HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True) if HAS_MLX_WHISPER: - from .mlx_encoder import mlx_model_mapping, load_mlx_encoder + from .mlx_encoder import load_mlx_encoder, mlx_model_mapping else: mlx_model_mapping = {} HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER) diff --git a/whisperlivekit/simul_whisper/beam.py b/whisperlivekit/simul_whisper/beam.py index cf61be7..fba600d 100644 --- a/whisperlivekit/simul_whisper/beam.py +++ b/whisperlivekit/simul_whisper/beam.py @@ -1,5 +1,6 @@ from whisperlivekit.whisper.decoding import PyTorchInference + # extention of PyTorchInference for beam search class BeamPyTorchInference(PyTorchInference): diff --git a/whisperlivekit/simul_whisper/config.py b/whisperlivekit/simul_whisper/config.py index 2562ce0..1897aac 100644 --- a/whisperlivekit/simul_whisper/config.py +++ b/whisperlivekit/simul_whisper/config.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field from typing import Literal + @dataclass class AlignAttConfig(): eval_data_path: str = "tmp" diff --git a/whisperlivekit/simul_whisper/mlx_encoder.py b/whisperlivekit/simul_whisper/mlx_encoder.py index 441166b..c9b0cd5 100644 --- a/whisperlivekit/simul_whisper/mlx_encoder.py +++ b/whisperlivekit/simul_whisper/mlx_encoder.py @@ -5,7 +5,6 @@ import mlx.core as mx import mlx.nn as nn from huggingface_hub import snapshot_download from mlx.utils import tree_unflatten - from mlx_whisper import whisper mlx_model_mapping = { diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index 61d93f3..e207249 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -1,33 +1,34 @@ -import os import logging - -import torch -import torch.nn.functional as F -import numpy as np - -from whisperlivekit.whisper import DecodingOptions, tokenizer -from .config import AlignAttConfig -from whisperlivekit.timed_objects import ASRToken -from whisperlivekit.whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES -from whisperlivekit.whisper.timing import median_filter -from whisperlivekit.whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens -from .beam import BeamPyTorchInference -from .eow_detection import fire_at_boundary, load_cif import os from time import time -from .token_buffer import TokenBuffer -from whisperlivekit.backend_support import ( - mlx_backend_available, - faster_backend_available, -) + +import numpy as np +import torch +import torch.nn.functional as F + +from whisperlivekit.backend_support import (faster_backend_available, + mlx_backend_available) +from whisperlivekit.timed_objects import ASRToken +from whisperlivekit.whisper import DecodingOptions, tokenizer +from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES, + TOKENS_PER_SECOND, + log_mel_spectrogram, pad_or_trim) +from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder, + SuppressTokens) +from whisperlivekit.whisper.timing import median_filter from ..timed_objects import PUNCTUATION_MARKS +from .beam import BeamPyTorchInference +from .config import AlignAttConfig +from .eow_detection import fire_at_boundary, load_cif +from .token_buffer import TokenBuffer DEC_PAD = 50257 logger = logging.getLogger(__name__) if mlx_backend_available(): - from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram + from mlx_whisper.audio import \ + log_mel_spectrogram as mlx_log_mel_spectrogram from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim if faster_backend_available(): diff --git a/whisperlivekit/simul_whisper/token_buffer.py b/whisperlivekit/simul_whisper/token_buffer.py index 1146591..c7ace39 100644 --- a/whisperlivekit/simul_whisper/token_buffer.py +++ b/whisperlivekit/simul_whisper/token_buffer.py @@ -1,5 +1,8 @@ -import torch import sys + +import torch + + class TokenBuffer: def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]): diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index dc2a729..23dfaec 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field -from typing import Optional, List, Union, Dict, Any from datetime import timedelta +from typing import Any, Dict, List, Optional, Union PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'} diff --git a/whisperlivekit/tokens_alignment.py b/whisperlivekit/tokens_alignment.py index dd72913..26e7e73 100644 --- a/whisperlivekit/tokens_alignment.py +++ b/whisperlivekit/tokens_alignment.py @@ -1,7 +1,9 @@ from time import time -from typing import Optional, List, Tuple, Union, Any +from typing import Any, List, Optional, Tuple, Union -from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence, TimedText, Segment +from whisperlivekit.timed_objects import (ASRToken, Line, Segment, Silence, + SilentLine, SpeakerSegment, + TimedText) class TokensAlignment: diff --git a/whisperlivekit/warmup.py b/whisperlivekit/warmup.py index e660e4d..0397fe4 100644 --- a/whisperlivekit/warmup.py +++ b/whisperlivekit/warmup.py @@ -7,6 +7,7 @@ def load_file(warmup_file=None, timeout=5): import os import tempfile import urllib.request + import librosa if warmup_file == "": diff --git a/whisperlivekit/web/web_interface.py b/whisperlivekit/web/web_interface.py index 7d20841..d8e13bc 100644 --- a/whisperlivekit/web/web_interface.py +++ b/whisperlivekit/web/web_interface.py @@ -1,6 +1,6 @@ -import logging -import importlib.resources as resources import base64 +import importlib.resources as resources +import logging logger = logging.getLogger(__name__) @@ -96,11 +96,13 @@ def get_inline_ui_html(): if __name__ == '__main__': + import pathlib + + import uvicorn from fastapi import FastAPI from fastapi.responses import HTMLResponse - import uvicorn from starlette.staticfiles import StaticFiles - import pathlib + import whisperlivekit.web as webpkg app = FastAPI() diff --git a/whisperlivekit/whisper/__init__.py b/whisperlivekit/whisper/__init__.py index e2bbdf5..1dd3504 100644 --- a/whisperlivekit/whisper/__init__.py +++ b/whisperlivekit/whisper/__init__.py @@ -4,15 +4,17 @@ import json import os import urllib import warnings +from pathlib import Path from typing import Dict, List, Optional, Union import torch -from tqdm import tqdm -from pathlib import Path from torch import Tensor +from tqdm import tqdm -from whisperlivekit.whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim -from whisperlivekit.whisper.decoding import DecodingOptions, DecodingResult, decode, detect_language +from whisperlivekit.whisper.audio import (load_audio, log_mel_spectrogram, + pad_or_trim) +from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult, + decode, detect_language) from whisperlivekit.whisper.model import ModelDimensions, Whisper from whisperlivekit.whisper.transcribe import transcribe from whisperlivekit.whisper.version import __version__ diff --git a/whisperlivekit/whisper/decoding.py b/whisperlivekit/whisper/decoding.py index 49485d0..c494c72 100644 --- a/whisperlivekit/whisper/decoding.py +++ b/whisperlivekit/whisper/decoding.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field, replace -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, + Tuple, Union) import numpy as np import torch diff --git a/whisperlivekit/whisper/transcribe.py b/whisperlivekit/whisper/transcribe.py index 0a4cc36..7c192b6 100644 --- a/whisperlivekit/whisper/transcribe.py +++ b/whisperlivekit/whisper/transcribe.py @@ -8,28 +8,13 @@ import numpy as np import torch import tqdm -from .audio import ( - FRAMES_PER_SECOND, - HOP_LENGTH, - N_FRAMES, - N_SAMPLES, - SAMPLE_RATE, - log_mel_spectrogram, - pad_or_trim, -) +from .audio import (FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES, + SAMPLE_RATE, log_mel_spectrogram, pad_or_trim) from .decoding import DecodingOptions, DecodingResult from .timing import add_word_timestamps from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer -from .utils import ( - exact_div, - format_timestamp, - get_end, - get_writer, - make_safe, - optional_float, - optional_int, - str2bool, -) +from .utils import (exact_div, format_timestamp, get_end, get_writer, + make_safe, optional_float, optional_int, str2bool) if TYPE_CHECKING: from .model import Whisper