284c50acd8
Removes the speaker diarization pipeline and alignment model from the STT module to reduce resource usage and complexity. The transcription API remains compatible by returning 'Unknown' as the speaker ID for all transcribed segments.

- Removed DiarizationPipeline and align_model from Transcriber
- Simplified transcribe method to return basic transcription segments
- Updated logging and docstrings to reflect changes
85 lines · 2.6 KiB · Python
import logging
|
|
|
|
import numpy as np
|
|
import whisperx
|
|
|
|
# Handlers and levels are configured by the orchestrator, so this module must
# not call logging.basicConfig(); it only obtains its own named logger.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class Transcriber:
    """
    Converts audio chunks (numpy arrays) into text using WhisperX.

    Speaker diarization is not performed: every returned segment carries the
    placeholder speaker ID "Unknown", so callers that consume
    (speaker_id, text, start, end) tuples keep working unchanged.
    """

    def __init__(
        self,
        model_size="base",
        device="cpu",
        compute_type="int8",
        language="en",
        batch_size=16,
    ):
        """
        Initializes the WhisperX model.

        Args:
            model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
            device (str): The device to run the model on ("cpu" or "cuda").
            compute_type (str): The compute type to use (e.g., "int8", "float16").
            language (str | None): Language code passed to WhisperX on every
                transcription call (e.g., "en"). None lets WhisperX detect the
                language for each chunk.
            batch_size (int): Batch size passed to WhisperX's transcribe();
                lower it if the device runs out of memory.

        Raises:
            Exception: Whatever whisperx.load_model raises is logged (with
                traceback) and re-raised.
        """
        self.device = device
        self.compute_type = compute_type
        self.language = language
        self.batch_size = batch_size

        logger.info(
            f"Loading WhisperX model: {model_size} on {device} ({compute_type})..."
        )
        try:
            # Load transcription model
            self.model = whisperx.load_model(
                model_size, device=device, compute_type=compute_type
            )
        except Exception as e:
            # logger.exception keeps the traceback, which a plain error log drops.
            logger.exception(f"Failed to load WhisperX models: {e}")
            raise
        logger.info("WhisperX model loaded successfully.")

    def transcribe(self, audio_chunk):
        """
        Transcribes an audio chunk.

        Args:
            audio_chunk (np.ndarray): The audio data as a numpy array (any
                array-like is accepted). It is copied into a flat float32 array
                before being handed to WhisperX.

        Returns:
            list: A list of tuples (speaker_id, text, start, end). speaker_id is
            always "Unknown"; segments whose text is blank are dropped. An empty
            list is returned for None or empty input, after close(), or when
            transcription fails (the error is logged).
        """
        if audio_chunk is None:
            return []

        if self.model is None:
            logger.warning("Transcriber has been closed; ignoring audio chunk.")
            return []

        try:
            # WhisperX expects audio in float32 and 1D array. np.array copies once
            # (like the previous astype()), so the caller's buffer is never shared.
            audio = np.array(audio_chunk, dtype=np.float32).ravel()
            if audio.size == 0:
                # Nothing to transcribe; avoid a pointless model call.
                return []

            # Passing the configured language skips WhisperX's per-chunk
            # language detection; None keeps auto-detection.
            result = self.model.transcribe(
                audio, batch_size=self.batch_size, language=self.language
            )

            # Extract ("Unknown", text, start, end) tuples from the transcription result
            output = []
            for segment in result.get("segments", []):
                text = segment.get("text", "").strip()
                if text:
                    output.append(
                        (
                            "Unknown",
                            text,
                            segment.get("start", 0.0),
                            segment.get("end", 0.0),
                        )
                    )
            return output
        except Exception as e:
            # Best-effort: a failed chunk must not stop the pipeline, but keep
            # the traceback for debugging.
            logger.exception(f"Transcription error: {e}")
            return []

    def close(self):
        """
        Release the WhisperX model so its memory can be reclaimed.

        Drops this object's reference to the model, runs the garbage collector
        and, when the device is CUDA, asks PyTorch to release its cached GPU
        memory. Calling it more than once is harmless; afterwards transcribe()
        returns an empty list.
        """
        import gc

        if getattr(self, "model", None) is None:
            return
        self.model = None
        gc.collect()

        if str(self.device).startswith("cuda"):
            try:
                import torch
            except ImportError:
                # Nothing more to release without PyTorch.
                return
            torch.cuda.empty_cache()