Migrate to WhisperX for speaker diarization
Implement a sliding window audio buffer and update the transcriber to use WhisperX for transcription, alignment, and speaker identification. Update the pipeline to handle and store speaker-attributed transcripts. Additionally, update the LLM processor's reasoning parameter to "enable_thinking".
This commit is contained in:
+64
-22
@@ -1,6 +1,9 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from faster_whisper import WhisperModel
|
||||
import numpy as np
|
||||
import whisperx
|
||||
from whisperx.diarize import DiarizationPipeline
|
||||
|
||||
# Do not call basicConfig here, as it's called in the orchestrator
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -8,62 +11,101 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class Transcriber:
|
||||
"""
|
||||
Converts audio chunks (numpy arrays) into text using faster-whisper.
|
||||
Converts audio chunks (numpy arrays) into text and identifies speakers using WhisperX.
|
||||
"""
|
||||
|
||||
def __init__(self, model_size="base", device="cpu", compute_type="int8"):
|
||||
def __init__(
|
||||
self, model_size="base", device="cpu", compute_type="int8", language="en"
|
||||
):
|
||||
"""
|
||||
Initializes the faster-whisper model.
|
||||
Initializes the WhisperX model and diarization pipeline.
|
||||
|
||||
Args:
|
||||
model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
|
||||
device (str): The device to run the model on ("cpu" or "cuda").
|
||||
compute_type (str): The compute type to use (e.g., "int8", "float16").
|
||||
language (str): The language code for alignment (e.g., "en").
|
||||
"""
|
||||
self.device = device
|
||||
self.compute_type = compute_type
|
||||
self.language = language
|
||||
|
||||
logger.info(
|
||||
f"Loading faster-whisper model: {model_size} on {device} ({compute_type})..."
|
||||
f"Loading WhisperX model: {model_size} on {device} ({compute_type})..."
|
||||
)
|
||||
try:
|
||||
self.model = WhisperModel(
|
||||
# Load transcription model
|
||||
self.model = whisperx.load_model(
|
||||
model_size, device=device, compute_type=compute_type
|
||||
)
|
||||
logger.info("Model loaded successfully.")
|
||||
|
||||
# Load alignment model (required for accurate speaker assignment)
|
||||
# model_dir=None allows automatic model selection based on the language
|
||||
self.align_model, self.align_metadata = whisperx.load_align_model(
|
||||
device=device, model_dir=None, language_code=self.language
|
||||
)
|
||||
|
||||
self.diarize_model = DiarizationPipeline()
|
||||
|
||||
logger.info("WhisperX and Diarization models loaded successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load faster-whisper model: {e}")
|
||||
logger.error(f"Failed to load WhisperX models: {e}")
|
||||
raise
|
||||
|
||||
def transcribe(self, audio_chunk):
|
||||
"""
|
||||
Transcribes a single audio chunk.
|
||||
Transcribes an audio chunk and performs speaker diarization.
|
||||
|
||||
Args:
|
||||
audio_chunk (np.ndarray): The audio data as a numpy array.
|
||||
|
||||
Returns:
|
||||
str: The transcribed text.
|
||||
list: A list of tuples (speaker_id, text).
|
||||
"""
|
||||
if audio_chunk is None:
|
||||
return ""
|
||||
return []
|
||||
|
||||
try:
|
||||
# faster-whisper expects audio in float32 and 1D array
|
||||
audio_data = audio_chunk.astype("float32").flatten()
|
||||
# WhisperX expects audio in float32 and 1D array
|
||||
audio = audio_chunk.astype("float32").flatten()
|
||||
|
||||
# Transcribe the audio
|
||||
segments, info = self.model.transcribe(audio_data, beam_size=5)
|
||||
# 1. Perform transcription
|
||||
# batch_size is set to 16 for efficiency; can be adjusted based on VRAM
|
||||
result = self.model.transcribe(audio, batch_size=16)
|
||||
|
||||
# Combine segments into a single string
|
||||
text = " ".join([segment.text.strip() for segment in segments])
|
||||
# 2. Perform alignment
|
||||
# Alignment is necessary for the assign_words_to_speakers step
|
||||
result_a = whisperx.align(
|
||||
result["segments"],
|
||||
self.align_model,
|
||||
self.align_metadata,
|
||||
audio,
|
||||
self.device,
|
||||
)
|
||||
|
||||
return text.strip()
|
||||
# 3. Perform diarization
|
||||
diarize_segments = self.diarize_model(audio)
|
||||
|
||||
# 4. Align transcription segments with speakers
|
||||
result_final = whisperx.assign_word_speakers(diarize_segments, result_a)
|
||||
|
||||
# Extract (speaker_id, text, start, end) tuples from the final result
|
||||
output = []
|
||||
for segment in result_final.get("segments", []):
|
||||
speaker = segment.get("speaker", "Unknown")
|
||||
text = segment.get("text", "").strip()
|
||||
start = segment.get("start", 0.0)
|
||||
end = segment.get("end", 0.0)
|
||||
if text:
|
||||
output.append((speaker, text, start, end))
|
||||
|
||||
return output
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription error: {e}")
|
||||
return ""
|
||||
logger.error(f"Transcription/Diarization error: {e}")
|
||||
return []
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Explicitly release model resources if necessary.
|
||||
"""
|
||||
# faster-whisper's WhisperModel doesn't have a standard close(),
|
||||
# but we'll provide this for consistency.
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user