refactor(stt): remove speaker identification (diarization) from transcriber
Removes the speaker diarization pipeline and alignment model from the STT module to reduce resource usage and complexity. The transcription API remains compatible by returning 'Unknown' as the speaker ID for all transcribed segments.

- Removed DiarizationPipeline and align_model from Transcriber
- Simplified the transcribe method to return basic transcription segments
- Updated logging and docstrings to reflect the changes
This commit is contained in:
+9
-36
@@ -1,9 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import whisperx
|
import whisperx
|
||||||
from whisperx.diarize import DiarizationPipeline
|
|
||||||
|
|
||||||
# Do not call basicConfig here, as it's called in the orchestrator
|
# Do not call basicConfig here, as it's called in the orchestrator
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -11,14 +9,14 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class Transcriber:
|
class Transcriber:
|
||||||
"""
|
"""
|
||||||
Converts audio chunks (numpy arrays) into text and identifies speakers using WhisperX.
|
Converts audio chunks (numpy arrays) into text using WhisperX.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_size="base", device="cpu", compute_type="int8", language="en"
|
self, model_size="base", device="cpu", compute_type="int8", language="en"
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the WhisperX model and diarization pipeline.
|
Initializes the WhisperX model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
|
model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
|
||||||
@@ -39,28 +37,20 @@ class Transcriber:
|
|||||||
model_size, device=device, compute_type=compute_type
|
model_size, device=device, compute_type=compute_type
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load alignment model (required for accurate speaker assignment)
|
logger.info("WhisperX model loaded successfully.")
|
||||||
# model_dir=None allows automatic model selection based on the language
|
|
||||||
self.align_model, self.align_metadata = whisperx.load_align_model(
|
|
||||||
device=device, model_dir=None, language_code=self.language
|
|
||||||
)
|
|
||||||
|
|
||||||
self.diarize_model = DiarizationPipeline()
|
|
||||||
|
|
||||||
logger.info("WhisperX and Diarization models loaded successfully.")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load WhisperX models: {e}")
|
logger.error(f"Failed to load WhisperX models: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def transcribe(self, audio_chunk):
|
def transcribe(self, audio_chunk):
|
||||||
"""
|
"""
|
||||||
Transcribes an audio chunk and performs speaker diarization.
|
Transcribes an audio chunk.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio_chunk (np.ndarray): The audio data as a numpy array.
|
audio_chunk (np.ndarray): The audio data as a numpy array.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: A list of tuples (speaker_id, text).
|
list: A list of tuples (speaker_id, text, start, end).
|
||||||
"""
|
"""
|
||||||
if audio_chunk is None:
|
if audio_chunk is None:
|
||||||
return []
|
return []
|
||||||
@@ -73,35 +63,18 @@ class Transcriber:
|
|||||||
# batch_size is set to 16 for efficiency; can be adjusted based on VRAM
|
# batch_size is set to 16 for efficiency; can be adjusted based on VRAM
|
||||||
result = self.model.transcribe(audio, batch_size=16)
|
result = self.model.transcribe(audio, batch_size=16)
|
||||||
|
|
||||||
# 2. Perform alignment
|
# Extract ("Unknown", text, start, end) tuples from the transcription result
|
||||||
# Alignment is necessary for the assign_words_to_speakers step
|
|
||||||
result_a = whisperx.align(
|
|
||||||
result["segments"],
|
|
||||||
self.align_model,
|
|
||||||
self.align_metadata,
|
|
||||||
audio,
|
|
||||||
self.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Perform diarization
|
|
||||||
diarize_segments = self.diarize_model(audio)
|
|
||||||
|
|
||||||
# 4. Align transcription segments with speakers
|
|
||||||
result_final = whisperx.assign_word_speakers(diarize_segments, result_a)
|
|
||||||
|
|
||||||
# Extract (speaker_id, text, start, end) tuples from the final result
|
|
||||||
output = []
|
output = []
|
||||||
for segment in result_final.get("segments", []):
|
for segment in result.get("segments", []):
|
||||||
speaker = segment.get("speaker", "Unknown")
|
|
||||||
text = segment.get("text", "").strip()
|
text = segment.get("text", "").strip()
|
||||||
start = segment.get("start", 0.0)
|
start = segment.get("start", 0.0)
|
||||||
end = segment.get("end", 0.0)
|
end = segment.get("end", 0.0)
|
||||||
if text:
|
if text:
|
||||||
output.append((speaker, text, start, end))
|
output.append(("Unknown", text, start, end))
|
||||||
|
|
||||||
return output
|
return output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Transcription/Diarization error: {e}")
|
logger.error(f"Transcription error: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user