Compare commits

..

6 Commits

Author SHA1 Message Date
charles 284c50acd8 refactor(stt): remove speaker identification (diarization) from transcriber
Removes the speaker diarization pipeline and alignment model from the STT module to reduce resource usage and complexity.
The transcription API remains compatible by returning 'Unknown' as the speaker ID for all transcribed segments.

- Removed DiarizationPipeline and align_model from Transcriber
- Simplified transcribe method to return basic transcription segments
- Updated logging and docstrings to reflect changes
2026-06-06 20:52:04 -07:00
charles 01b049cf37 Update main.py 2026-06-05 23:10:39 -07:00
charles da5ab1bb44 Refactor STT pipeline and CLI documentation
Split the STT worker into a collector and a transcription worker
to offload heavy processing to a background thread. Add the
`--whisper-model` flag and implement LLM latency logging. Expand
the README with comprehensive CLI usage instructions.
2026-05-31 15:04:41 -07:00
charles 71ecdb3468 Add LLM configuration and pipeline execution 2026-05-31 14:13:58 -07:00
charles 2858c7e235 Update tui.py 2026-05-30 23:03:04 -07:00
charles 15dfbfb467 Add LLM backend support and improve debugging observability
- Add LLM_BACKEND to environment configuration
- Implement detailed debug logging for LLM request/response cycles
- Add missing llama-index dependencies for embeddings and chroma
- Update prompt constraints to prevent lore redundancy
- Enable CUDA for transcription and set logging to DEBUG level
- Add entry point for running the orchestrator directly
- Cleanup unused comment in TUI context updates
2026-05-28 23:06:25 -07:00
10 changed files with 280 additions and 94 deletions
+1
View File
@@ -2,6 +2,7 @@
OPENAI_API_KEY=no-key-required
OPENAI_BASE_URL=https://vllm.tipsy.codes/v1
LLM_MODEL=google/gemma-4-26b-a4b-it
LLM_BACKEND=vllm
#LLM_BACKEND=ollama
#LLM_MODEL=gemma:2b
WHISPER_MODEL=base
+48 -2
View File
@@ -28,8 +28,54 @@ Distill long sessions into concise highlights. Use LLMs to summarize recorded tr
## Interface & Usage
- **CLI**: The primary interface for confirming automated updates and querying current game state.
- **Text Editors**: Since data is stored in Markdown and JSON, you can use any editor (VS Code, Vim, Obsidian) to manually refine your campaign data.
### CLI
The primary interface for confirming automated updates and querying current game state.
#### Command Line Arguments
Use these flags to manage data ingestion and run the live capture pipeline.
##### RAG Ingestion
Use these flags to add external documents to the RAG (Retrieval-Augmented Generation) system.
| Flag | Description |
| :--- | :--- |
| `--ingest-pdf <path>` | Path to a PDF file to ingest |
| `--ingest-file <path>` | Path to a markdown file to ingest |
| `--ingest-dir <path>` | Path to a directory of markdown files to ingest |
##### LLM Configuration
These flags allow you to override the environment variables for the LLM backend.
| Flag | Description |
| :--- | :--- |
| `--llm-backend <backend>` | Backend to use (`openai`, `ollama`, or `vllm`) |
| `--llm-model <model>` | The model name to use |
| `--llm-api-key <key>` | API key for the LLM backend |
| `--llm-base-url <url>` | Base URL for the LLM backend |
##### Pipeline Execution
| Flag | Description |
| :--- | :--- |
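| `--run-pipeline` | Starts the main orchestration pipeline (TUI + STT + LLM) |
| `--whisper-model <model>` | The Whisper model to use for STT (defaults to `WHISPER_MODEL`, or `turbo` if unset) |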
##### Example Command
To run the live orchestration pipeline using the configuration specified in your `env.sh`, you can use:
```bash
python main.py --run-pipeline \
--llm-backend vllm \
--llm-model google/gemma-4-26b-a4b-it \
--llm-api-key no-key-required \
--whisper-model medium \
--llm-base-url https://vllm.tipsy.codes/v1
```
### Text Editors
Since data is stored in Markdown and JSON, you can use any editor (VS Code, Vim, Obsidian) to manually refine your campaign data.
## Technical Stack
+72 -3
View File
@@ -1,10 +1,15 @@
import argparse
import asyncio
import os
from src.pipeline.orchestrator import PipelineOrchestrator
from src.rag.manager import RAGManager
def main():
parser = argparse.ArgumentParser(description="D&D Helpers CLI")
# RAG Ingestion Arguments
parser.add_argument(
"--ingest-pdf",
type=str,
@@ -21,9 +26,73 @@ def main():
help="Path to a directory of markdown files to ingest into the RAG system",
)
# LLM Configuration Arguments
parser.add_argument(
"--llm-backend",
type=str,
choices=["openai", "ollama", "vllm"],
default=os.environ.get("LLM_BACKEND", "openai"),
help="LLM backend to use",
)
parser.add_argument(
"--llm-model",
type=str,
default=os.environ.get("LLM_MODEL", "gpt-4o"),
help="The model to use for processing",
)
parser.add_argument(
"--llm-api-key",
type=str,
default=os.environ.get("OPENAI_API_KEY"),
help="API key for the LLM backend",
)
parser.add_argument(
"--llm-base-url",
type=str,
default=os.environ.get("OPENAI_BASE_URL"),
help="Base URL for the LLM backend",
)
# STT Configuration Arguments
parser.add_argument(
"--whisper-model",
type=str,
default=os.environ.get("WHISPER_MODEL", "turbo"),
help="The Whisper model to use for STT",
)
# Pipeline Execution Argument
parser.add_argument(
"--run-pipeline",
action="store_true",
help="Start the main orchestration pipeline (TUI + STT + LLM)",
)
args = parser.parse_args()
rag_manager = RAGManager()
llm_config = {
"backend": args.llm_backend,
"model": args.llm_model,
"api_key": args.llm_api_key,
"base_url": args.llm_base_url,
}
# Remove None values to allow defaults to take over if not provided
llm_config = {k: v for k, v in llm_config.items() if v is not None}
if args.run_pipeline:
async def run_pipeline():
    """Create the orchestrator on the running event loop and run it.

    The orchestrator is stopped if the run is interrupted, whether the
    interruption surfaces as KeyboardInterrupt or as cancellation of
    this coroutine's task.
    """
    # We are inside a coroutine, so a loop is guaranteed to be running;
    # get_running_loop() is the documented way to fetch it from here.
    loop = asyncio.get_running_loop()
    orchestrator = PipelineOrchestrator(loop, llm_config=llm_config, whisper_model=args.whisper_model)
    try:
        await orchestrator.run()
    except KeyboardInterrupt:
        orchestrator.stop()
    except asyncio.CancelledError:
        # Under asyncio.run(), Ctrl-C can arrive as cancellation of the
        # main task rather than as KeyboardInterrupt. Stop the pipeline,
        # then re-raise so the cancellation is not swallowed.
        orchestrator.stop()
        raise
asyncio.run(run_pipeline())
return
rag_manager = RAGManager(llm_config=llm_config)
if args.ingest_pdf:
print(f"Ingesting PDF: {args.ingest_pdf}...")
@@ -40,8 +109,8 @@ def main():
rag_manager.ingest_directory(args.ingest_dir)
print("Directory ingestion complete.")
if not any([args.ingest_pdf, args.ingest_file, args.ingest_dir]):
print("Hello from dnd-helpers!")
if not any([args.ingest_pdf, args.ingest_file, args.ingest_dir, args.run_pipeline]):
print("Hello from dnd-helpers! Use --help to see available commands.")
if __name__ == "__main__":
+2
View File
@@ -9,3 +9,5 @@ python-dotenv
llama-index
chromadb
pdfplumber
llama-index-embeddings-huggingface
llama-index-vector_stores-chroma
+32 -6
View File
@@ -1,5 +1,6 @@
import logging
import os
import time
from posix import system
from this import s
from typing import Any, Dict, Optional
@@ -23,6 +24,7 @@ class LLMProcessor:
api_key: Optional[str] = None,
base_url: Optional[str] = None,
model: Optional[str] = None,
backend: Optional[str] = None,
):
"""
Initializes the LLMProcessor.
@@ -30,14 +32,16 @@ class LLMProcessor:
:param api_key: OpenAI API key. If None, it looks for OPENAI_API_KEY in environment variables.
:param base_url: OpenAI-compatible base URL (e.g., for vLLM).
:param model: The model to use for processing. If None, it looks for LLM_MODEL in environment variables.
:param backend: The LLM backend to use (openai, ollama, or vllm).
"""
backend = os.environ.get("LLM_BACKEND", "openai").lower()
# Use provided backend or fallback to environment variable
backend_env = backend or os.environ.get("LLM_BACKEND", "openai").lower()
if backend == "ollama":
if backend_env == "ollama":
# Ollama's OpenAI-compatible API
final_base_url = base_url or "http://localhost:11434/v1"
final_api_key = api_key or "ollama"
elif backend == "vllm":
elif backend_env == "vllm":
# Remote vLLM server
final_base_url = base_url or os.environ.get("OPENAI_BASE_URL")
final_api_key = api_key or os.environ.get("OPENAI_API_KEY")
@@ -45,18 +49,19 @@ class LLMProcessor:
final_base_url = base_url or os.environ.get("OPENAI_BASE_URL")
final_api_key = api_key or os.environ.get("OPENAI_API_KEY")
logger.info(f"Using LLM backend: {backend_env}")
try:
self.client = OpenAI(
api_key=final_api_key,
base_url=final_base_url,
)
# Simple connectivity check for local backends
if backend == "ollama":
if backend_env == "ollama":
# We can't easily check connectivity without making a call,
# but we can ensure the client is initialized.
pass
except Exception as e:
logger.error(f"Error initializing LLM client for backend {backend}: {e}")
logger.error(f"Error initializing LLM client for backend {backend_env}: {e}")
raise
self.model = model or os.environ.get("LLM_MODEL", "gpt-4o")
@@ -96,15 +101,36 @@ class LLMProcessor:
messages.append({"role": "user", "content": user_prompt})
# Debugging: Dump inputs
logger.debug("--- LLM CALL START ---")
logger.debug(f"Model: {self.model}")
logger.debug(f"Messages: {messages}")
if response_format:
logger.debug(f"Response Format: {response_format}")
logger.debug("--- LLM CALL END ---")
try:
start_time = time.perf_counter()
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
response_format=response_format,
extra_body={"enable_thinking": False},
extra_body={
"chat_template_kwargs": {
"enable_thinking": False
}
},
)
elapsed_time = time.perf_counter() - start_time
logger.info(f"LLM request completed in {elapsed_time:.2f}s")
content = response.choices[0].message.content
# Debugging: Dump outputs
logger.debug("--- LLM RESPONSE START ---")
logger.debug(f"Content: {content}")
logger.debug("--- LLM RESPONSE END ---")
return self._strip_markdown_code_blocks(content)
except Exception as e:
logger.error(f"LLM Error: {e}")
+1
View File
@@ -28,6 +28,7 @@ CONSTRAINTS:
- OUTPUT ONLY VALID JSON.
- If no relevant information is found, return empty lists for all keys.
- If a character name is not specified (e.g., "Your character"), use "Player Character".
- Do not repeat lore if it is already known; only provide new or updated facts.
Strict Output Format:
Return a JSON object with exactly these keys:
+80 -17
View File
@@ -23,7 +23,7 @@ from src.ui.tui import ConfirmationApp
# Configure logging to write to a file instead of stdout
logging.basicConfig(
level=logging.INFO,
level=logging.DEBUG,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
handlers=[
logging.FileHandler("pipeline.log"),
@@ -39,14 +39,20 @@ logger = logging.getLogger(__name__)
class PipelineOrchestrator:
def __init__(self, loop: asyncio.AbstractEventLoop):
def __init__(
self,
loop: asyncio.AbstractEventLoop,
llm_config: Optional[dict] = None,
whisper_model: str = "base",
):
self.loop = loop
self.llm_config = llm_config or {}
# Modules
self.listener = AudioListener(loop=self.loop)
self.transcriber = Transcriber(model_size="base")
self.processor = LLMProcessor()
self.rag_manager = RAGManager()
self.transcriber = Transcriber(model_size=whisper_model, device="cuda")
self.processor = LLMProcessor(**self.llm_config)
self.rag_manager = RAGManager(llm_config=self.llm_config)
# Queues
self.stt_to_clean_queue = asyncio.Queue()
@@ -56,6 +62,9 @@ class PipelineOrchestrator:
self.log_queue = asyncio.Queue()
self.persistence_queue = asyncio.Queue()
# Synchronization
self.transcription_event = asyncio.Event()
self.is_running = False
# Conversation history for context
@@ -83,11 +92,12 @@ class PipelineOrchestrator:
return f"Conversation History:\n{context_text}\n\n"
async def stt_worker(self):
async def stt_collector_worker(self):
"""
Worker that handles STT: Audio -> Text.
Worker that handles STT Collection: Audio -> Buffer.
This task is highly responsive and only manages the buffer.
"""
logger.info("STT Worker started.")
logger.info("STT Collector Worker started.")
while self.is_running:
try:
# Get audio chunk from listener
@@ -104,33 +114,68 @@ class PipelineOrchestrator:
):
self.audio_buffer.pop(0)
# Concatenate buffer for transcription
full_audio = np.concatenate(self.audio_buffer)
# Signal the transcription worker that new data is available
self.transcription_event.set()
# Transcribe (WhisperX now returns a list of (speaker, text, start, end))
except Exception as e:
logger.error(f"STT Collector Worker error: {e}")
# Small sleep to prevent tight loop
await asyncio.sleep(0.01)
async def stt_transcription_worker(self):
"""
Worker that handles STT Transcription: Buffer -> Text.
This task handles the heavy lifting in a separate thread.
"""
logger.info("STT Transcription Worker started.")
while self.is_running:
try:
# Wait for a signal that new data is available
await self.transcription_event.wait()
self.transcription_event.clear()
# 1. Take a snapshot of the current buffer to avoid race conditions
# while the collector is appending new chunks.
buffer_snapshot = list(self.audio_buffer)
if not buffer_snapshot:
continue
# 2. Perform transcription in a separate thread.
# We pass the snapshot to the helper which handles concatenation and transcription.
results = await asyncio.to_thread(
self.transcriber.transcribe, full_audio
self._transcribe_buffer_snapshot, buffer_snapshot
)
# Filter for only new segments that start after the last processed segment
# 3. Filter for only new segments that start after the last processed segment
new_segments = [
res for res in results if res[2] >= self.last_processed_end_time
]
if new_segments:
for speaker, text, start, end in new_segments:
logger.info(f"Transcribed: [{speaker}] {text}")
logger.info(f"STT Raw Transcription: [{speaker}] {text}")
# Push raw transcription to log queue for UI visibility
await self.log_queue.put(f"[{speaker}] {text}")
await self.stt_to_clean_queue.put((speaker, text))
self.last_processed_end_time = max(
self.last_processed_end_time, end
)
except Exception as e:
logger.error(f"STT Worker error: {e}")
logger.error(f"STT Transcription Worker error: {e}")
# Small sleep to prevent tight loop if get_chunk is fast
# Small sleep to prevent tight loop
await asyncio.sleep(0.1)
def _transcribe_buffer_snapshot(self, buffer_snapshot):
    """
    Helper method to be run in a thread.
    Concatenates the buffer snapshot and transcribes it.

    Args:
        buffer_snapshot: A sequence of audio chunks (numpy arrays) copied
            from the audio buffer.

    Returns:
        list: Whatever ``self.transcriber.transcribe`` returns for the
        concatenated audio, or an empty list when the snapshot holds no
        chunks.
    """
    # np.concatenate raises ValueError on an empty sequence; report
    # "nothing transcribed" instead of failing inside the worker thread.
    if len(buffer_snapshot) == 0:
        return []
    full_audio = np.concatenate(buffer_snapshot)
    return self.transcriber.transcribe(full_audio)
async def clean_worker(self):
"""
Worker that handles Text Cleaning: Raw STT -> Filtered Text.
@@ -203,6 +248,7 @@ class PipelineOrchestrator:
while self.is_running:
try:
logger.info("LLM Worker: Waiting for input...")
speaker, text = await internal_queue.get()
logger.info(f"LLM Worker: Processing text from {speaker}: {text}")
@@ -212,6 +258,9 @@ class PipelineOrchestrator:
# Log the text sent to the LLM for UI affordance
await self.log_queue.put(f"[{speaker}] {text}")
# Log the filtered message being sent to the LLM
logger.info(f"LLM Worker: Sending filtered message to LLM: {text}")
# Structured extraction using the processor
extraction_result = await asyncio.to_thread(
self.processor.extract_structured_data,
@@ -300,7 +349,8 @@ class PipelineOrchestrator:
# Start workers as background tasks
tasks = [
asyncio.create_task(self.stt_worker()),
asyncio.create_task(self.stt_collector_worker()),
asyncio.create_task(self.stt_transcription_worker()),
asyncio.create_task(self.clean_worker()),
asyncio.create_task(self.llm_worker()),
asyncio.create_task(self.persistence_worker()),
@@ -328,3 +378,16 @@ class PipelineOrchestrator:
Stops.
"""
self.is_running = False
if __name__ == "__main__":
    import asyncio

    async def main():
        """Run the orchestrator directly with its default configuration.

        The orchestrator is stopped if the run is interrupted, whether the
        interruption surfaces as KeyboardInterrupt or as cancellation of
        this coroutine's task.
        """
        # We are inside a coroutine, so a loop is guaranteed to be running;
        # get_running_loop() is the documented way to fetch it from here.
        loop = asyncio.get_running_loop()
        orchestrator = PipelineOrchestrator(loop)
        try:
            await orchestrator.run()
        except KeyboardInterrupt:
            orchestrator.stop()
        except asyncio.CancelledError:
            # Under asyncio.run(), Ctrl-C can arrive as cancellation of the
            # main task rather than as KeyboardInterrupt. Stop the pipeline,
            # then re-raise so the cancellation is not swallowed.
            orchestrator.stop()
            raise

    asyncio.run(main())
+3 -2
View File
@@ -12,8 +12,9 @@ from src.llm.processor import LLMProcessor
class RAGManager:
def __init__(self, persist_dir: str = "data/rag_index"):
def __init__(self, persist_dir: str = "data/rag_index", llm_config: Optional[dict] = None):
self.persist_dir = persist_dir
self.llm_config = llm_config or {}
self.db = chromadb.PersistentClient(path=self.persist_dir)
self.collection_name = "phb_collection"
@@ -110,7 +111,7 @@ class RAGManager:
if not nodes:
return []
processor = LLMProcessor()
processor = LLMProcessor(**self.llm_config)
# Construct the context from retrieved nodes
context_text = "\n\n".join(
+9 -36
View File
@@ -1,9 +1,7 @@
import logging
import os
import numpy as np
import whisperx
from whisperx.diarize import DiarizationPipeline
# Do not call basicConfig here, as it's called in the orchestrator
logger = logging.getLogger(__name__)
@@ -11,14 +9,14 @@ logger = logging.getLogger(__name__)
class Transcriber:
"""
Converts audio chunks (numpy arrays) into text and identifies speakers using WhisperX.
Converts audio chunks (numpy arrays) into text using WhisperX.
"""
def __init__(
self, model_size="base", device="cpu", compute_type="int8", language="en"
):
"""
Initializes the WhisperX model and diarization pipeline.
Initializes the WhisperX model.
Args:
model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
@@ -39,28 +37,20 @@ class Transcriber:
model_size, device=device, compute_type=compute_type
)
# Load alignment model (required for accurate speaker assignment)
# model_dir=None allows automatic model selection based on the language
self.align_model, self.align_metadata = whisperx.load_align_model(
device=device, model_dir=None, language_code=self.language
)
self.diarize_model = DiarizationPipeline()
logger.info("WhisperX and Diarization models loaded successfully.")
logger.info("WhisperX model loaded successfully.")
except Exception as e:
logger.error(f"Failed to load WhisperX models: {e}")
raise
def transcribe(self, audio_chunk):
"""
Transcribes an audio chunk and performs speaker diarization.
Transcribes an audio chunk.
Args:
audio_chunk (np.ndarray): The audio data as a numpy array.
Returns:
list: A list of tuples (speaker_id, text).
list: A list of tuples (speaker_id, text, start, end).
"""
if audio_chunk is None:
return []
@@ -73,35 +63,18 @@ class Transcriber:
# batch_size is set to 16 for efficiency; can be adjusted based on VRAM
result = self.model.transcribe(audio, batch_size=16)
# 2. Perform alignment
# Alignment is necessary for the assign_words_to_speakers step
result_a = whisperx.align(
result["segments"],
self.align_model,
self.align_metadata,
audio,
self.device,
)
# 3. Perform diarization
diarize_segments = self.diarize_model(audio)
# 4. Align transcription segments with speakers
result_final = whisperx.assign_word_speakers(diarize_segments, result_a)
# Extract (speaker_id, text, start, end) tuples from the final result
# Extract ("Unknown", text, start, end) tuples from the transcription result
output = []
for segment in result_final.get("segments", []):
speaker = segment.get("speaker", "Unknown")
for segment in result.get("segments", []):
text = segment.get("text", "").strip()
start = segment.get("start", 0.0)
end = segment.get("end", 0.0)
if text:
output.append((speaker, text, start, end))
output.append(("Unknown", text, start, end))
return output
except Exception as e:
logger.error(f"Transcription/Diarization error: {e}")
logger.error(f"Transcription error: {e}")
return []
def close(self):
+30 -26
View File
@@ -16,20 +16,26 @@ from textual.widgets import (
Static,
)
from src.llm.models import CharacterStateUpdate, ExtractionResult, LoreUpdate
from src.llm.models import CharacterStateUpdate, ContextUpdate, ExtractionResult, LoreUpdate
from src.persistence.characters import update_character_state
from src.persistence.lore import update_lore
class EditModal(ModalScreen):
def __init__(self, initial_text: str, on_save: callable):
def __init__(self, initial_text: str, initial_type: str, initial_target: str, on_save: callable):
super().__init__()
self.initial_text = initial_text
self.initial_type = initial_type
self.initial_target = initial_target
self.on_save = on_save
def compose(self) -> ComposeResult:
with Vertical(id="modal-container"):
yield Label("Edit Fact Content:")
yield Label("Type:")
yield Input(value=self.initial_type, id="edit-type")
yield Label("Target:")
yield Input(value=self.initial_target, id="edit-target")
yield Label("Content:")
yield Input(value=self.initial_text, id="edit-input")
with Horizontal(id="modal-actions"):
yield Button("Save", id="btn-save")
@@ -38,7 +44,9 @@ class EditModal(ModalScreen):
def on_button_pressed(self, event: Button.Pressed) -> None:
if event.button.id == "btn-save":
edit_input = self.query_one("#edit-input", Input)
self.on_save(edit_input.value)
type_input = self.query_one("#edit-type", Input)
target_input = self.query_one("#edit-target", Input)
self.on_save(edit_input.value, type_input.value, target_input.value)
self.dismiss()
elif event.button.id == "btn-cancel":
self.dismiss()
@@ -68,32 +76,27 @@ class ConfirmationApp(App):
}
#pending-facts-table {
height: 40%;
height: 30%;
border: solid white;
}
#llm-input-container {
height: 10%;
border: solid white;
padding: 1;
padding: 0;
}
#context-pane {
height: 50%;
height: 60%;
border: solid white;
}
#log-pane {
height: 30%;
height: 100%;
border: solid white;
background: #111;
}
#log-footer {
height: 70%;
border: solid white;
}
#modal-container {
width: 60%;
height: auto;
@@ -109,7 +112,7 @@ class ConfirmationApp(App):
align: right middle;
}
#edit-input {
#edit-input, #edit-type, #edit-target {
margin: 1 0;
}
@@ -155,18 +158,11 @@ class ConfirmationApp(App):
Horizontal(
Vertical(
DataTable(id="pending-facts-table"),
Vertical(
Input(placeholder="Message LLM...", id="llm-input"),
id="llm-input-container",
),
ListView(id="context-pane"),
id="left-pane",
),
Vertical(
ListView(id="log-pane"),
Static("LATEST LLM INPUTS", id="log-footer"),
id="right-pane",
),
id="content-wrapper",
),
id="main-container",
@@ -199,7 +195,7 @@ class ConfirmationApp(App):
table = self.query_one("#pending-facts-table", DataTable)
if isinstance(update, LoreUpdate):
table.add_row(
"Lore", update.entity_name or "General", update.content, key=str(index)
update.category, update.entity_name or "General", update.content, key=str(index)
)
elif isinstance(update, CharacterStateUpdate):
change_text = f"HP: {update.hp_change or 0}"
@@ -218,8 +214,6 @@ class ConfirmationApp(App):
elif isinstance(update, ContextUpdate):
display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
context_list = self.query_one("#context-pane", ListView)
# ListView.insert takes an *iterable* of ListItems; passing a
# bare ListItem raises TypeError because ListItem is not iterable.
# Insert at the top to show most recent first.
await context_list.insert(0, [ListItem(Static(display_text))])
if hasattr(self.llm_to_ui_queue, "task_done"):
@@ -294,24 +288,34 @@ class ConfirmationApp(App):
update = self.pending_updates[row_index]
initial_text = ""
initial_type = ""
initial_target = ""
if isinstance(update, LoreUpdate):
initial_text = update.content
initial_type = update.category
initial_target = update.entity_name or ""
elif isinstance(update, CharacterStateUpdate):
initial_text = str(update.hp_change or 0)
initial_type = "Char"
initial_target = update.character_name
def save_callback(new_text: str):
def save_callback(new_text: str, new_type: str, new_target: str):
if isinstance(update, LoreUpdate):
update.content = new_text
update.category = new_type
update.entity_name = new_target if new_target else None
elif isinstance(update, CharacterStateUpdate):
try:
update.hp_change = int(new_text)
except ValueError:
pass
update.character_name = new_target
# Update the table
self.refresh_table()
self.push_screen(EditModal(initial_text, save_callback))
self.push_screen(EditModal(initial_text, initial_type, initial_target, save_callback))
def remove_update(self, index: int) -> None:
del self.pending_updates[index]