diff --git a/whisperlivekit/diarization/sortformer_backend.py b/whisperlivekit/diarization/sortformer_backend.py index 3072eae..84652a3 100644 --- a/whisperlivekit/diarization/sortformer_backend.py +++ b/whisperlivekit/diarization/sortformer_backend.py @@ -60,11 +60,15 @@ class SortformerDiarization: self.diar_model = SortformerEncLabelModel.from_pretrained(model_name) self.diar_model.eval() - if torch.cuda.is_available(): - self.diar_model.to(torch.device("cuda")) - logger.info("Using CUDA for Sortformer model") - else: - logger.info("Using CPU for Sortformer model") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.diar_model.to(device) + + ## to test + # for name, param in self.diar_model.named_parameters(): + # if param.device != device: + # raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}") + + logger.info(f"Using {device.type.upper()} for Sortformer model") self.diar_model.sortformer_modules.chunk_len = 10 self.diar_model.sortformer_modules.subsampling_factor = 10 @@ -187,22 +191,25 @@ class SortformerDiarizationOnline: audio = self.buffer_audio[:threshold] self.buffer_audio = self.buffer_audio[threshold:] - audio_signal_chunk = torch.tensor(audio).unsqueeze(0).to(self.diar_model.device) - audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(self.diar_model.device) + device = self.diar_model.device + audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0) + audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device) processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features( audio_signal_chunk, audio_signal_length_chunk ) + processed_signal_chunk = processed_signal_chunk.to(device) + processed_signal_length_chunk = processed_signal_length_chunk.to(device) if self._previous_chunk_features is not None: - to_add = self._previous_chunk_features[:, :, -99:] - total_features = torch.concat([to_add, processed_signal_chunk], dim=2) + to_add = self._previous_chunk_features[:, :, -99:].to(device) + total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device) else: - total_features = processed_signal_chunk + total_features = processed_signal_chunk.to(device) - self._previous_chunk_features = processed_signal_chunk + self._previous_chunk_features = processed_signal_chunk.to(device) - chunk_feat_seq_t = torch.transpose(total_features, 1, 2) + chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device) with torch.inference_mode(): left_offset = 8 if self._chunk_index > 0 else 0 @@ -210,7 +217,7 @@ class SortformerDiarizationOnline: self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step( processed_signal=chunk_feat_seq_t, - processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]), + processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device), streaming_state=self.streaming_state, total_preds=self.total_preds, left_offset=left_offset,