refactor(stt): remove speaker identification (diarization) from transcriber

Removes the speaker diarization pipeline and alignment model from the STT module to reduce resource usage and complexity.
The transcription API remains compatible by returning 'Unknown' as the speaker ID for all transcribed segments.

- Removed DiarizationPipeline and align_model from Transcriber
- Simplified transcribe method to return basic transcription segments
- Updated logging and docstrings to reflect changes
This commit is contained in:
2026-06-06 20:52:04 -07:00
parent 01b049cf37
commit 284c50acd8
+9 -36
View File
@@ -1,9 +1,7 @@
import logging
import os
import numpy as np
import whisperx
from whisperx.diarize import DiarizationPipeline
# Do not call basicConfig here, as it's called in the orchestrator
logger = logging.getLogger(__name__)
@@ -11,14 +9,14 @@ logger = logging.getLogger(__name__)
class Transcriber:
"""
Converts audio chunks (numpy arrays) into text and identifies speakers using WhisperX.
Converts audio chunks (numpy arrays) into text using WhisperX.
"""
def __init__(
self, model_size="base", device="cpu", compute_type="int8", language="en"
):
"""
Initializes the WhisperX model and diarization pipeline.
Initializes the WhisperX model.
Args:
model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
@@ -39,28 +37,20 @@ class Transcriber:
model_size, device=device, compute_type=compute_type
)
# Load alignment model (required for accurate speaker assignment)
# model_dir=None allows automatic model selection based on the language
self.align_model, self.align_metadata = whisperx.load_align_model(
device=device, model_dir=None, language_code=self.language
)
self.diarize_model = DiarizationPipeline()
logger.info("WhisperX and Diarization models loaded successfully.")
logger.info("WhisperX model loaded successfully.")
except Exception as e:
logger.error(f"Failed to load WhisperX models: {e}")
raise
def transcribe(self, audio_chunk):
"""
Transcribes an audio chunk and performs speaker diarization.
Transcribes an audio chunk.
Args:
audio_chunk (np.ndarray): The audio data as a numpy array.
Returns:
list: A list of tuples (speaker_id, text).
list: A list of tuples (speaker_id, text, start, end).
"""
if audio_chunk is None:
return []
@@ -73,35 +63,18 @@ class Transcriber:
# batch_size is set to 16 for efficiency; can be adjusted based on VRAM
result = self.model.transcribe(audio, batch_size=16)
# 2. Perform alignment
# Alignment is necessary for the assign_words_to_speakers step
result_a = whisperx.align(
result["segments"],
self.align_model,
self.align_metadata,
audio,
self.device,
)
# 3. Perform diarization
diarize_segments = self.diarize_model(audio)
# 4. Align transcription segments with speakers
result_final = whisperx.assign_word_speakers(diarize_segments, result_a)
# Extract (speaker_id, text, start, end) tuples from the final result
# Extract ("Unknown", text, start, end) tuples from the transcription result
output = []
for segment in result_final.get("segments", []):
speaker = segment.get("speaker", "Unknown")
for segment in result.get("segments", []):
text = segment.get("text", "").strip()
start = segment.get("start", 0.0)
end = segment.get("end", 0.0)
if text:
output.append((speaker, text, start, end))
output.append(("Unknown", text, start, end))
return output
except Exception as e:
logger.error(f"Transcription/Diarization error: {e}")
logger.error(f"Transcription error: {e}")
return []
def close(self):