Compare commits

..

8 Commits

Author SHA1 Message Date
charles 284c50acd8 refactor(stt): remove speaker identification (diarization) from transcriber
Removes the speaker diarization pipeline and alignment model from the STT module to reduce resource usage and complexity.
The transcription API remains compatible by returning 'Unknown' as the speaker ID for all transcribed segments.

- Removed DiarizationPipeline and align_model from Transcriber
- Simplified transcribe method to return basic transcription segments
- Updated logging and docstrings to reflect changes
2026-06-06 20:52:04 -07:00
charles 01b049cf37 Update main.py 2026-06-05 23:10:39 -07:00
charles da5ab1bb44 Refactor STT pipeline and CLI documentation
Split the STT worker into a collector and a transcription worker
to offload heavy processing to a background thread. Add the
`--whisper-model` flag and implement LLM latency logging. Expand
the README with comprehensive CLI usage instructions.
2026-05-31 15:04:41 -07:00
charles 71ecdb3468 Add LLM configuration and pipeline execution 2026-05-31 14:13:58 -07:00
charles 2858c7e235 Update tui.py 2026-05-30 23:03:04 -07:00
charles 15dfbfb467 Add LLM backend support and improve debugging observability
- Add LLM_BACKEND to environment configuration
- Implement detailed debug logging for LLM request/response cycles
- Add missing llama-index dependencies for embeddings and chroma
- Update prompt constraints to prevent lore redundancy
- Enable CUDA for transcription and set logging to DEBUG level
- Add entry point for running the orchestrator directly
- Cleanup unused comment in TUI context updates
2026-05-28 23:06:25 -07:00
charles 49127d695a Small changes 2026-05-28 22:08:00 -07:00
charles 2363cde160 Refactor LLM processor and improve async handling
Move contextual information handling from noise filtering to extraction
and centralize LLM call logic. Wrap blocking transcription and state
update calls in asyncio.to_thread to prevent event loop blocking.
Update transcriber model size to base.
2026-05-28 18:54:09 -07:00
11 changed files with 341 additions and 174 deletions
+2 -1
View File
@@ -1,7 +1,8 @@
# D&D Helpers Configuration
OPENAI_API_KEY=no-key-required
OPENAI_BASE_URL=https://vllm.tipsy.codes/v1
LLM_MODEL=Intel/gemma-4-31B-it-int4-AutoRound
LLM_MODEL=google/gemma-4-26b-a4b-it
LLM_BACKEND=vllm
#LLM_BACKEND=ollama
#LLM_MODEL=gemma:2b
WHISPER_MODEL=base
+48 -2
View File
@@ -28,8 +28,54 @@ Distill long sessions into concise highlights. Use LLMs to summarize recorded tr
## Interface & Usage
- **CLI**: The primary interface for confirming automated updates and querying current game state.
- **Text Editors**: Since data is stored in Markdown and JSON, you can use any editor (VS Code, Vim, Obsidian) to manually refine your campaign data.
### CLI
The primary interface for confirming automated updates and querying current game state.
#### Command Line Arguments
Use these flags to manage data ingestion and run the live capture pipeline.
##### RAG Ingestion
Use these flags to add external documents to the RAG (Retrieval-Augmented Generation) system.
| Flag | Description |
| :--- | :--- |
| `--ingest-pdf <path>` | Path to a PDF file to ingest |
| `--ingest-file <path>` | Path to a markdown file to ingest |
| `--ingest-dir <path>` | Path to a directory of markdown files to ingest |
##### LLM Configuration
These flags allow you to override the environment variables for the LLM backend.
| Flag | Description |
| :--- | :--- |
| `--llm-backend <backend>` | Backend to use (`openai`, `ollama`, or `vllm`) |
| `--llm-model <model>` | The model name to use |
| `--llm-api-key <key>` | API key for the LLM backend |
| `--llm-base-url <url>` | Base URL for the LLM backend |
##### Pipeline Execution
| Flag | Description |
| :--- | :--- |
| `--run-pipeline` | Starts the main orchestration pipeline (TUI + STT + LLM) |
##### Example Command
To run the live orchestration pipeline using the configuration specified in your `env.sh`, you can use:
```bash
python main.py --run-pipeline \
--llm-backend vllm \
--llm-model google/gemma-4-26b-a4b-it \
--llm-api-key no-key-required \
--whisper-model medium \
--llm-base-url https://vllm.tipsy.codes/v1
```
### Text Editors
Since data is stored in Markdown and JSON, you can use any editor (VS Code, Vim, Obsidian) to manually refine your campaign data.
## Technical Stack
+72 -3
View File
@@ -1,10 +1,15 @@
import argparse
import asyncio
import os
from src.pipeline.orchestrator import PipelineOrchestrator
from src.rag.manager import RAGManager
def main():
parser = argparse.ArgumentParser(description="D&D Helpers CLI")
# RAG Ingestion Arguments
parser.add_argument(
"--ingest-pdf",
type=str,
@@ -21,9 +26,73 @@ def main():
help="Path to a directory of markdown files to ingest into the RAG system",
)
# LLM Configuration Arguments
parser.add_argument(
"--llm-backend",
type=str,
choices=["openai", "ollama", "vllm"],
default=os.environ.get("LLM_BACKEND", "openai"),
help="LLM backend to use",
)
parser.add_argument(
"--llm-model",
type=str,
default=os.environ.get("LLM_MODEL", "gpt-4o"),
help="The model to use for processing",
)
parser.add_argument(
"--llm-api-key",
type=str,
default=os.environ.get("OPENAI_API_KEY"),
help="API key for the LLM backend",
)
parser.add_argument(
"--llm-base-url",
type=str,
default=os.environ.get("OPENAI_BASE_URL"),
help="Base URL for the LLM backend",
)
# STT Configuration Arguments
parser.add_argument(
"--whisper-model",
type=str,
default=os.environ.get("WHISPER_MODEL", "turbo"),
help="The Whisper model to use for STT",
)
# Pipeline Execution Argument
parser.add_argument(
"--run-pipeline",
action="store_true",
help="Start the main orchestration pipeline (TUI + STT + LLM)",
)
args = parser.parse_args()
rag_manager = RAGManager()
llm_config = {
"backend": args.llm_backend,
"model": args.llm_model,
"api_key": args.llm_api_key,
"base_url": args.llm_base_url,
}
# Remove None values to allow defaults to take over if not provided
llm_config = {k: v for k, v in llm_config.items() if v is not None}
if args.run_pipeline:
async def run_pipeline():
loop = asyncio.get_event_loop()
orchestrator = PipelineOrchestrator(loop, llm_config=llm_config, whisper_model=args.whisper_model)
try:
await orchestrator.run()
except KeyboardInterrupt:
orchestrator.stop()
asyncio.run(run_pipeline())
return
rag_manager = RAGManager(llm_config=llm_config)
if args.ingest_pdf:
print(f"Ingesting PDF: {args.ingest_pdf}...")
@@ -40,8 +109,8 @@ def main():
rag_manager.ingest_directory(args.ingest_dir)
print("Directory ingestion complete.")
if not any([args.ingest_pdf, args.ingest_file, args.ingest_dir]):
print("Hello from dnd-helpers!")
if not any([args.ingest_pdf, args.ingest_file, args.ingest_dir, args.run_pipeline]):
print("Hello from dnd-helpers! Use --help to see available commands.")
if __name__ == "__main__":
+2
View File
@@ -9,3 +9,5 @@ python-dotenv
llama-index
chromadb
pdfplumber
llama-index-embeddings-huggingface
llama-index-vector_stores-chroma
-4
View File
@@ -55,10 +55,6 @@ class ContextUpdate(BaseModel):
class FilterResult(BaseModel):
contextual_info: str = Field(
...,
description="Information interesting to the user but not useful for structured extraction",
)
filtered_text: str = Field(
..., description="Cleaned transcript used for structured data extraction"
)
+48 -33
View File
@@ -1,5 +1,6 @@
import logging
import os
import time
from posix import system
from this import s
from typing import Any, Dict, Optional
@@ -23,6 +24,7 @@ class LLMProcessor:
api_key: Optional[str] = None,
base_url: Optional[str] = None,
model: Optional[str] = None,
backend: Optional[str] = None,
):
"""
Initializes the LLMProcessor.
@@ -30,14 +32,16 @@ class LLMProcessor:
:param api_key: OpenAI API key. If None, it looks for OPENAI_API_KEY in environment variables.
:param base_url: OpenAI-compatible base URL (e.g., for vLLM).
:param model: The model to use for processing. If None, it looks for LLM_MODEL in environment variables.
:param backend: The LLM backend to use (openai, ollama, or vllm).
"""
backend = os.environ.get("LLM_BACKEND", "openai").lower()
# Use provided backend or fallback to environment variable
backend_env = backend or os.environ.get("LLM_BACKEND", "openai").lower()
if backend == "ollama":
if backend_env == "ollama":
# Ollama's OpenAI-compatible API
final_base_url = base_url or "http://localhost:11434/v1"
final_api_key = api_key or "ollama"
elif backend == "vllm":
elif backend_env == "vllm":
# Remote vLLM server
final_base_url = base_url or os.environ.get("OPENAI_BASE_URL")
final_api_key = api_key or os.environ.get("OPENAI_API_KEY")
@@ -45,22 +49,35 @@ class LLMProcessor:
final_base_url = base_url or os.environ.get("OPENAI_BASE_URL")
final_api_key = api_key or os.environ.get("OPENAI_API_KEY")
logger.info(f"Using LLM backend: {backend_env}")
try:
self.client = OpenAI(
api_key=final_api_key,
base_url=final_base_url,
)
# Simple connectivity check for local backends
if backend == "ollama":
if backend_env == "ollama":
# We can't easily check connectivity without making a call,
# but we can ensure the client is initialized.
pass
except Exception as e:
logger.error(f"Error initializing LLM client for backend {backend}: {e}")
logger.error(f"Error initializing LLM client for backend {backend_env}: {e}")
raise
self.model = model or os.environ.get("LLM_MODEL", "gpt-4o")
def _strip_markdown_code_blocks(self, content: str) -> str:
"""
Strips markdown code blocks (e.g., ```json ... ```) from the content.
"""
import re
# Remove opening and closing code blocks
content = re.sub(
r"^```(?:json)?\n?|```$", "", content, flags=re.MULTILINE
).strip()
return content
def _call_llm(
self,
system_prompt: str,
@@ -84,24 +101,37 @@ class LLMProcessor:
messages.append({"role": "user", "content": user_prompt})
# Debugging: Dump inputs
logger.debug("--- LLM CALL START ---")
logger.debug(f"Model: {self.model}")
logger.debug(f"Messages: {messages}")
if response_format:
logger.debug(f"Response Format: {response_format}")
logger.debug("--- LLM CALL END ---")
try:
start_time = time.perf_counter()
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
response_format=response_format,
extra_body={"enable_thinking": False},
extra_body={
"chat_template_kwargs": {
"enable_thinking": False
}
},
)
elapsed_time = time.perf_counter() - start_time
logger.info(f"LLM request completed in {elapsed_time:.2f}s")
content = response.choices[0].message.content
# Strip markdown code blocks if present
if content.startswith("```"):
import re
# Debugging: Dump outputs
logger.debug("--- LLM RESPONSE START ---")
logger.debug(f"Content: {content}")
logger.debug("--- LLM RESPONSE END ---")
content = re.sub(
r"^```(?:json)?\n?|```$", "", content, flags=re.MULTILINE
).strip()
return content
return self._strip_markdown_code_blocks(content)
except Exception as e:
logger.error(f"LLM Error: {e}")
return ""
@@ -147,34 +177,19 @@ class LLMProcessor:
"""
logger.info(f"LLM Processor (Extract): Calling extraction for: {filtered_text}")
try:
# Using standard chat.completions.create with JSON mode for better compatibility with vLLM
logger.info("LLM Processor (Extract): Sending request to backend...")
system_prompt = EXTRACTION_SYSTEM_PROMPT
if context:
system_prompt += f"\n{context}"
messages = [
{"role": "system", "content": system_prompt},
]
messages.append({"role": "user", "content": filtered_text})
for message in messages:
logger.info(f"LLM Processor (Extract): Message: {message}")
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
result = self._call_llm(
system_prompt=system_prompt,
user_prompt=filtered_text,
response_format={"type": "json_object"},
extra_body={"enable_thinking": False},
)
logger.info("LLM Processor (Extract): Response received from backend.")
import json
content = response.choices[0].message.content
logger.info(f"LLM Processor (Extract): Raw JSON response: {content}")
data = json.loads(content)
data = json.loads(result)
# Map the JSON data to the Pydantic model
return ExtractionResult(**data)
+3 -5
View File
@@ -12,7 +12,6 @@ NOISE_FILTER_SYSTEM_PROMPT = """
You are a D&D Game Master's assistant. Given a transcript, remove all out-of-character (OOC) chatter, logistical discussions (e.g., 'Where is my d20?'), and non-relevant noise.
You must output your response as a JSON object with the following keys:
- "contextual_info": Information that is interesting or relevant to the story/session but doesn't fit into lore, character state, or significant events (e.g., flavor text, atmospheric descriptions, player commentary that adds context).
- "filtered_text": The cleaned transcript. IMPORTANT: Keep all player questions, requests for rule clarifications, and mentions of spells, NPCs, or locations in this field, as they are used to trigger knowledge base lookups.
Keep the original speakers' names if they are present in the transcript.
@@ -22,15 +21,14 @@ Do not add any commentary or summaries. Just filter the text.
EXTRACTION_SYSTEM_PROMPT = """
You are a D&D session analyzer. Your goal is to extract structured data from a filtered transcript.
Extract any changes to character states (HP, status effects, inventory) and any new lore facts (NPCs, locations, world-building).
DO NOT THINK.
In addition extracting updates to character state and lore, look for the oppertunity to provide useful context,
such as the answer to a player's question or the resolution of a lore fact.
CONSTRAINTS:
- OUTPUT ONLY VALID JSON.
- DO NOT include any commentary, explanations, or "thought" blocks.
- DO NOT include any keys other than "lore", "character_state", and "events".
- If no relevant information is found, return empty lists for all keys.
- If a character name is not specified (e.g., "Your character"), use "Player Character".
- Do not repeat lore if it is already known; only provide new or updated facts.
Strict Output Format:
Return a JSON object with exactly these keys:
+112 -51
View File
@@ -23,7 +23,7 @@ from src.ui.tui import ConfirmationApp
# Configure logging to write to a file instead of stdout
logging.basicConfig(
level=logging.INFO,
level=logging.DEBUG,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
handlers=[
logging.FileHandler("pipeline.log"),
@@ -39,14 +39,20 @@ logger = logging.getLogger(__name__)
class PipelineOrchestrator:
def __init__(self, loop: asyncio.AbstractEventLoop):
def __init__(
self,
loop: asyncio.AbstractEventLoop,
llm_config: Optional[dict] = None,
whisper_model: str = "base",
):
self.loop = loop
self.llm_config = llm_config or {}
# Modules
self.listener = AudioListener(loop=self.loop)
self.transcriber = Transcriber(model_size="small")
self.processor = LLMProcessor()
self.rag_manager = RAGManager()
self.transcriber = Transcriber(model_size=whisper_model, device="cuda")
self.processor = LLMProcessor(**self.llm_config)
self.rag_manager = RAGManager(llm_config=self.llm_config)
# Queues
self.stt_to_clean_queue = asyncio.Queue()
@@ -54,6 +60,10 @@ class PipelineOrchestrator:
self.clean_to_llm_queue = asyncio.Queue()
self.llm_to_ui_queue = asyncio.Queue()
self.log_queue = asyncio.Queue()
self.persistence_queue = asyncio.Queue()
# Synchronization
self.transcription_event = asyncio.Event()
self.is_running = False
@@ -82,11 +92,12 @@ class PipelineOrchestrator:
return f"Conversation History:\n{context_text}\n\n"
async def stt_worker(self):
async def stt_collector_worker(self):
"""
Worker that handles STT: Audio -> Text.
Worker that handles STT Collection: Audio -> Buffer.
This task is highly responsive and only manages the buffer.
"""
logger.info("STT Worker started.")
logger.info("STT Collector Worker started.")
while self.is_running:
try:
# Get audio chunk from listener
@@ -103,31 +114,68 @@ class PipelineOrchestrator:
):
self.audio_buffer.pop(0)
# Concatenate buffer for transcription
full_audio = np.concatenate(self.audio_buffer)
# Signal the transcription worker that new data is available
self.transcription_event.set()
# Transcribe (WhisperX now returns a list of (speaker, text, start, end))
results = self.transcriber.transcribe(full_audio)
except Exception as e:
logger.error(f"STT Collector Worker error: {e}")
# Filter for only new segments that start after the last processed segment
# Small sleep to prevent tight loop
await asyncio.sleep(0.01)
async def stt_transcription_worker(self):
"""
Worker that handles STT Transcription: Buffer -> Text.
This task handles the heavy lifting in a separate thread.
"""
logger.info("STT Transcription Worker started.")
while self.is_running:
try:
# Wait for a signal that new data is available
await self.transcription_event.wait()
self.transcription_event.clear()
# 1. Take a snapshot of the current buffer to avoid race conditions
# while the collector is appending new chunks.
buffer_snapshot = list(self.audio_buffer)
if not buffer_snapshot:
continue
# 2. Perform transcription in a separate thread.
# We pass the snapshot to the helper which handles concatenation and transcription.
results = await asyncio.to_thread(
self._transcribe_buffer_snapshot, buffer_snapshot
)
# 3. Filter for only new segments that start after the last processed segment
new_segments = [
res for res in results if res[2] >= self.last_processed_end_time
]
if new_segments:
for speaker, text, start, end in new_segments:
logger.info(f"Transcribed: [{speaker}] {text}")
logger.info(f"STT Raw Transcription: [{speaker}] {text}")
# Push raw transcription to log queue for UI visibility
await self.log_queue.put(f"[{speaker}] {text}")
await self.stt_to_clean_queue.put((speaker, text))
self.last_processed_end_time = max(
self.last_processed_end_time, end
)
except Exception as e:
logger.error(f"STT Worker error: {e}")
logger.error(f"STT Transcription Worker error: {e}")
# Small sleep to prevent tight loop if get_chunk is fast
# Small sleep to prevent tight loop
await asyncio.sleep(0.1)
def _transcribe_buffer_snapshot(self, buffer_snapshot):
"""
Helper method to be run in a thread.
Concatenates the buffer snapshot and transcribes it.
"""
full_audio = np.concatenate(buffer_snapshot)
return self.transcriber.transcribe(full_audio)
async def clean_worker(self):
"""
Worker that handles Text Cleaning: Raw STT -> Filtered Text.
@@ -184,8 +232,11 @@ class PipelineOrchestrator:
async def feed_ui():
while self.is_running:
try:
text = await self.ui_to_llm_queue.get()
await internal_queue.put(("UI", text))
item = await self.ui_to_llm_queue.get()
if isinstance(item, (LoreUpdate, CharacterStateUpdate)):
await self.persistence_queue.put(item)
else:
await internal_queue.put(("UI", item))
except Exception as e:
logger.error(f"LLM Feeder (UI) error: {e}")
@@ -197,6 +248,7 @@ class PipelineOrchestrator:
while self.is_running:
try:
logger.info("LLM Worker: Waiting for input...")
speaker, text = await internal_queue.get()
logger.info(f"LLM Worker: Processing text from {speaker}: {text}")
@@ -206,6 +258,9 @@ class PipelineOrchestrator:
# Log the text sent to the LLM for UI affordance
await self.log_queue.put(f"[{speaker}] {text}")
# Log the filtered message being sent to the LLM
logger.info(f"LLM Worker: Sending filtered message to LLM: {text}")
# Structured extraction using the processor
extraction_result = await asyncio.to_thread(
self.processor.extract_structured_data,
@@ -213,20 +268,8 @@ class PipelineOrchestrator:
context=context,
)
# Persistence: Lore Updates
for lore_update in extraction_result.lore_updates:
file_path = await asyncio.to_thread(update_lore, lore_update)
await asyncio.to_thread(self.rag_manager.ingest_file, file_path)
logger.info(
f"LLM Worker: Lore updated and ingested into RAG: {lore_update.entity_name}"
)
# Persistence: Character State Updates
for char_update in extraction_result.character_updates:
await asyncio.to_thread(update_character_state, char_update)
logger.info(
f"LLM Worker: Character {char_update.character_name} state updated."
)
# Send the entire result to UI for confirmation
await self.llm_to_ui_queue.put(extraction_result)
# UI Notification: Context Updates
for context_update in extraction_result.context_updates:
@@ -243,29 +286,32 @@ class PipelineOrchestrator:
for f in feeders:
f.cancel()
def _get_wiki_context(self) -> str:
async def persistence_worker(self):
"""
Reads all files in the lore directory and returns them as a 저희 context string.
Worker that handles persistence: Confirmed updates -> Disk & RAG.
"""
from src.persistence.lore import DATA_LORE_DIR
wiki_contents = []
# Recursively find all .md files in the lore directory
for path in DATA_LORE_DIR.rglob("*.md"):
logger.info("Persistence Worker started.")
while self.is_running:
try:
with open(path, "r", encoding="utf-8") as f:
content = f.read()
wiki_contents.append(
f"File: {path.relative_to(DATA_LORE_DIR)}\nContent:\n{content}"
update = await self.persistence_queue.get()
if isinstance(update, LoreUpdate):
file_path = await asyncio.to_thread(update_lore, update)
await asyncio.to_thread(self.rag_manager.ingest_file, file_path)
logger.info(
f"Persistence Worker: Lore updated and ingested into RAG: {update.entity_name}"
)
elif isinstance(update, CharacterStateUpdate):
await asyncio.to_thread(update_character_state, update)
logger.info(
f"Persistence Worker: Character {update.character_name} state updated."
)
except Exception as e:
logger.error(f"Error reading wiki file {path}: {e}")
return (
"\n\n".join(wiki_contents)
if wiki_contents
else "No wiki knowledge available."
)
if hasattr(self.persistence_queue, "task_done"):
self.persistence_queue.task_done()
except Exception as e:
logger.error(f"Persistence Worker error: {e}")
await asyncio.sleep(0.1)
async def tui_worker(self):
"""
@@ -303,9 +349,11 @@ class PipelineOrchestrator:
# Start workers as background tasks
tasks = [
asyncio.create_task(self.stt_worker()),
asyncio.create_task(self.stt_collector_worker()),
asyncio.create_task(self.stt_transcription_worker()),
asyncio.create_task(self.clean_worker()),
asyncio.create_task(self.llm_worker()),
asyncio.create_task(self.persistence_worker()),
asyncio.create_task(self.tui_worker()),
]
@@ -330,3 +378,16 @@ class PipelineOrchestrator:
Stops.
"""
self.is_running = False
if __name__ == "__main__":
import asyncio
async def main():
loop = asyncio.get_event_loop()
orchestrator = PipelineOrchestrator(loop)
try:
await orchestrator.run()
except KeyboardInterrupt:
orchestrator.stop()
asyncio.run(main())
+3 -2
View File
@@ -12,8 +12,9 @@ from src.llm.processor import LLMProcessor
class RAGManager:
def __init__(self, persist_dir: str = "data/rag_index"):
def __init__(self, persist_dir: str = "data/rag_index", llm_config: Optional[dict] = None):
self.persist_dir = persist_dir
self.llm_config = llm_config or {}
self.db = chromadb.PersistentClient(path=self.persist_dir)
self.collection_name = "phb_collection"
@@ -110,7 +111,7 @@ class RAGManager:
if not nodes:
return []
processor = LLMProcessor()
processor = LLMProcessor(**self.llm_config)
# Construct the context from retrieved nodes
context_text = "\n\n".join(
+9 -36
View File
@@ -1,9 +1,7 @@
import logging
import os
import numpy as np
import whisperx
from whisperx.diarize import DiarizationPipeline
# Do not call basicConfig here, as it's called in the orchestrator
logger = logging.getLogger(__name__)
@@ -11,14 +9,14 @@ logger = logging.getLogger(__name__)
class Transcriber:
"""
Converts audio chunks (numpy arrays) into text and identifies speakers using WhisperX.
Converts audio chunks (numpy arrays) into text using WhisperX.
"""
def __init__(
self, model_size="base", device="cpu", compute_type="int8", language="en"
):
"""
Initializes the WhisperX model and diarization pipeline.
Initializes the WhisperX model.
Args:
model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
@@ -39,28 +37,20 @@ class Transcriber:
model_size, device=device, compute_type=compute_type
)
# Load alignment model (required for accurate speaker assignment)
# model_dir=None allows automatic model selection based on the language
self.align_model, self.align_metadata = whisperx.load_align_model(
device=device, model_dir=None, language_code=self.language
)
self.diarize_model = DiarizationPipeline()
logger.info("WhisperX and Diarization models loaded successfully.")
logger.info("WhisperX model loaded successfully.")
except Exception as e:
logger.error(f"Failed to load WhisperX models: {e}")
raise
def transcribe(self, audio_chunk):
"""
Transcribes an audio chunk and performs speaker diarization.
Transcribes an audio chunk.
Args:
audio_chunk (np.ndarray): The audio data as a numpy array.
Returns:
list: A list of tuples (speaker_id, text).
list: A list of tuples (speaker_id, text, start, end).
"""
if audio_chunk is None:
return []
@@ -73,35 +63,18 @@ class Transcriber:
# batch_size is set to 16 for efficiency; can be adjusted based on VRAM
result = self.model.transcribe(audio, batch_size=16)
# 2. Perform alignment
# Alignment is necessary for the assign_words_to_speakers step
result_a = whisperx.align(
result["segments"],
self.align_model,
self.align_metadata,
audio,
self.device,
)
# 3. Perform diarization
diarize_segments = self.diarize_model(audio)
# 4. Align transcription segments with speakers
result_final = whisperx.assign_word_speakers(diarize_segments, result_a)
# Extract (speaker_id, text, start, end) tuples from the final result
# Extract ("Unknown", text, start, end) tuples from the transcription result
output = []
for segment in result_final.get("segments", []):
speaker = segment.get("speaker", "Unknown")
for segment in result.get("segments", []):
text = segment.get("text", "").strip()
start = segment.get("start", 0.0)
end = segment.get("end", 0.0)
if text:
output.append((speaker, text, start, end))
output.append(("Unknown", text, start, end))
return output
except Exception as e:
logger.error(f"Transcription/Diarization error: {e}")
logger.error(f"Transcription error: {e}")
return []
def close(self):
+42 -37
View File
@@ -16,20 +16,26 @@ from textual.widgets import (
Static,
)
from src.llm.models import CharacterStateUpdate, ExtractionResult, LoreUpdate
from src.llm.models import CharacterStateUpdate, ContextUpdate, ExtractionResult, LoreUpdate
from src.persistence.characters import update_character_state
from src.persistence.lore import update_lore
class EditModal(ModalScreen):
def __init__(self, initial_text: str, on_save: callable):
def __init__(self, initial_text: str, initial_type: str, initial_target: str, on_save: callable):
super().__init__()
self.initial_text = initial_text
self.initial_type = initial_type
self.initial_target = initial_target
self.on_save = on_save
def compose(self) -> ComposeResult:
with Vertical(id="modal-container"):
yield Label("Edit Fact Content:")
yield Label("Type:")
yield Input(value=self.initial_type, id="edit-type")
yield Label("Target:")
yield Input(value=self.initial_target, id="edit-target")
yield Label("Content:")
yield Input(value=self.initial_text, id="edit-input")
with Horizontal(id="modal-actions"):
yield Button("Save", id="btn-save")
@@ -38,7 +44,9 @@ class EditModal(ModalScreen):
def on_button_pressed(self, event: Button.Pressed) -> None:
if event.button.id == "btn-save":
edit_input = self.query_one("#edit-input", Input)
self.on_save(edit_input.value)
type_input = self.query_one("#edit-type", Input)
target_input = self.query_one("#edit-target", Input)
self.on_save(edit_input.value, type_input.value, target_input.value)
self.dismiss()
elif event.button.id == "btn-cancel":
self.dismiss()
@@ -68,32 +76,27 @@ class ConfirmationApp(App):
}
#pending-facts-table {
height: 40%;
height: 30%;
border: solid white;
}
#llm-input-container {
height: 10%;
border: solid white;
padding: 1;
padding: 0;
}
#context-pane {
height: 50%;
height: 60%;
border: solid white;
}
#log-pane {
height: 30%;
height: 100%;
border: solid white;
background: #111;
}
#log-footer {
height: 70%;
border: solid white;
}
#modal-container {
width: 60%;
height: auto;
@@ -109,7 +112,7 @@ class ConfirmationApp(App):
align: right middle;
}
#edit-input {
#edit-input, #edit-type, #edit-target {
margin: 1 0;
}
@@ -155,18 +158,11 @@ class ConfirmationApp(App):
Horizontal(
Vertical(
DataTable(id="pending-facts-table"),
Vertical(
Input(placeholder="Message LLM...", id="llm-input"),
id="llm-input-container",
),
Input(placeholder="Message LLM...", id="llm-input"),
ListView(id="context-pane"),
id="left-pane",
),
Vertical(
ListView(id="log-pane"),
Static("LATEST LLM INPUTS", id="log-footer"),
id="right-pane",
),
ListView(id="log-pane"),
id="content-wrapper",
),
id="main-container",
@@ -199,7 +195,7 @@ class ConfirmationApp(App):
table = self.query_one("#pending-facts-table", DataTable)
if isinstance(update, LoreUpdate):
table.add_row(
"Lore", update.entity_name or "General", update.content, key=str(index)
update.category, update.entity_name or "General", update.content, key=str(index)
)
elif isinstance(update, CharacterStateUpdate):
change_text = f"HP: {update.hp_change or 0}"
@@ -213,12 +209,13 @@ class ConfirmationApp(App):
while True:
try:
update = await self.llm_to_ui_queue.get()
display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
context_list = self.query_one("#context-pane", ListView)
# ListView.insert takes an *iterable* of ListItems; passing a
# bare ListItem raises TypeError because ListItem is not iterable.
# Insert at the top to show most recent first.
await context_list.insert(0, [ListItem(Static(display_text))])
if isinstance(update, ExtractionResult):
self.handle_proposal_result(update)
elif isinstance(update, ContextUpdate):
display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
context_list = self.query_one("#context-pane", ListView)
# Insert at the top to show most recent first.
await context_list.insert(0, [ListItem(Static(display_text))])
if hasattr(self.llm_to_ui_queue, "task_done"):
self.llm_to_ui_queue.task_done()
except Exception as e:
@@ -263,17 +260,15 @@ class ConfirmationApp(App):
self.ui_to_llm_queue.put_nowait(text)
input_widget.value = ""
def action_accept(self) -> None:
async def action_accept(self) -> None:
table = self.query_one("#pending-facts-table", DataTable)
row_index = table.cursor_row
if row_index < 0 or row_index >= len(self.pending_updates):
return
update = self.pending_updates[row_index]
if isinstance(update, LoreUpdate):
update_lore(update)
elif isinstance(update, CharacterStateUpdate):
update_character_state(update)
if self.ui_to_llm_queue:
self.ui_to_llm_queue.put_nowait(update)
self.remove_update(row_index)
@@ -293,24 +288,34 @@ class ConfirmationApp(App):
update = self.pending_updates[row_index]
initial_text = ""
initial_type = ""
initial_target = ""
if isinstance(update, LoreUpdate):
initial_text = update.content
initial_type = update.category
initial_target = update.entity_name or ""
elif isinstance(update, CharacterStateUpdate):
initial_text = str(update.hp_change or 0)
initial_type = "Char"
initial_target = update.character_name
def save_callback(new_text: str):
def save_callback(new_text: str, new_type: str, new_target: str):
if isinstance(update, LoreUpdate):
update.content = new_text
update.category = new_type
update.entity_name = new_target if new_target else None
elif isinstance(update, CharacterStateUpdate):
try:
update.hp_change = int(new_text)
except ValueError:
pass
update.character_name = new_target
# Update the table
self.refresh_table()
self.push_screen(EditModal(initial_text, save_callback))
self.push_screen(EditModal(initial_text, initial_type, initial_target, save_callback))
def remove_update(self, index: int) -> None:
del self.pending_updates[index]