Compare commits
6 Commits
49127d695a
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 284c50acd8 | |||
| 01b049cf37 | |||
| da5ab1bb44 | |||
| 71ecdb3468 | |||
| 2858c7e235 | |||
| 15dfbfb467 |
@@ -2,6 +2,7 @@
|
|||||||
OPENAI_API_KEY=no-key-required
|
OPENAI_API_KEY=no-key-required
|
||||||
OPENAI_BASE_URL=https://vllm.tipsy.codes/v1
|
OPENAI_BASE_URL=https://vllm.tipsy.codes/v1
|
||||||
LLM_MODEL=google/gemma-4-26b-a4b-it
|
LLM_MODEL=google/gemma-4-26b-a4b-it
|
||||||
|
LLM_BACKEND=vllm
|
||||||
#LLM_BACKEND=ollama
|
#LLM_BACKEND=ollama
|
||||||
#LLM_MODEL=gemma:2b
|
#LLM_MODEL=gemma:2b
|
||||||
WHISPER_MODEL=base
|
WHISPER_MODEL=base
|
||||||
|
|||||||
@@ -28,8 +28,54 @@ Distill long sessions into concise highlights. Use LLMs to summarize recorded tr
|
|||||||
|
|
||||||
## Interface & Usage
|
## Interface & Usage
|
||||||
|
|
||||||
- **CLI**: The primary interface for confirming automated updates and querying current game state.
|
### CLI
|
||||||
- **Text Editors**: Since data is stored in Markdown and JSON, you can use any editor (VS Code, Vim, Obsidian) to manually refine your campaign data.
|
|
||||||
|
The primary interface for confirming automated updates and querying current game state.
|
||||||
|
|
||||||
|
#### Command Line Arguments
|
||||||
|
|
||||||
|
Use these flags to manage data ingestion, configure the LLM backend, and run the live capture pipeline.
|
||||||
|
|
||||||
|
##### RAG Ingestion
|
||||||
|
Use these flags to add external documents to the RAG (Retrieval-Augmented Generation) system.
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| :--- | :--- |
|
||||||
|
| `--ingest-pdf <path>` | Path to a PDF file to ingest |
|
||||||
|
| `--ingest-file <path>` | Path to a markdown file to ingest |
|
||||||
|
| `--ingest-dir <path>` | Path to a directory of markdown files to ingest |
|
||||||
|
|
||||||
|
##### LLM Configuration
|
||||||
|
These flags allow you to override the environment variables for the LLM backend.
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| :--- | :--- |
|
||||||
|
| `--llm-backend <backend>` | Backend to use (`openai`, `ollama`, or `vllm`) |
|
||||||
|
| `--llm-model <model>` | The model name to use |
|
||||||
|
| `--llm-api-key <key>` | API key for the LLM backend |
|
||||||
|
| `--llm-base-url <url>` | Base URL for the LLM backend |
|
||||||
|
|
||||||
|
##### Pipeline Execution
|
||||||
|
| Flag | Description |
|
||||||
|
| :--- | :--- |
|
||||||
|
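| `--run-pipeline` | Starts the main orchestration pipeline (TUI + STT + LLM) |
| `--whisper-model <model>` | The Whisper model to use for STT (defaults to `WHISPER_MODEL`, or `turbo` if unset) |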
|
||||||
|
|
||||||
|
##### Example Command
|
||||||
|
|
||||||
|
To run the live orchestration pipeline with the configuration passed explicitly on the command line (these flags override the corresponding values from your `env.sh`), you can use:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python main.py --run-pipeline \
|
||||||
|
--llm-backend vllm \
|
||||||
|
--llm-model google/gemma-4-26b-a4b-it \
|
||||||
|
--llm-api-key no-key-required \
|
||||||
|
--whisper-model medium \
|
||||||
|
--llm-base-url https://vllm.tipsy.codes/v1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Text Editors
|
||||||
|
|
||||||
|
Since data is stored in Markdown and JSON, you can use any editor (VS Code, Vim, Obsidian) to manually refine your campaign data.
|
||||||
|
|
||||||
## Technical Stack
|
## Technical Stack
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,15 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
from src.pipeline.orchestrator import PipelineOrchestrator
|
||||||
from src.rag.manager import RAGManager
|
from src.rag.manager import RAGManager
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="D&D Helpers CLI")
|
parser = argparse.ArgumentParser(description="D&D Helpers CLI")
|
||||||
|
|
||||||
|
# RAG Ingestion Arguments
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ingest-pdf",
|
"--ingest-pdf",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -21,9 +26,73 @@ def main():
|
|||||||
help="Path to a directory of markdown files to ingest into the RAG system",
|
help="Path to a directory of markdown files to ingest into the RAG system",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# LLM Configuration Arguments
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-backend",
|
||||||
|
type=str,
|
||||||
|
choices=["openai", "ollama", "vllm"],
|
||||||
|
default=os.environ.get("LLM_BACKEND", "openai"),
|
||||||
|
help="LLM backend to use",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-model",
|
||||||
|
type=str,
|
||||||
|
default=os.environ.get("LLM_MODEL", "gpt-4o"),
|
||||||
|
help="The model to use for processing",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-api-key",
|
||||||
|
type=str,
|
||||||
|
default=os.environ.get("OPENAI_API_KEY"),
|
||||||
|
help="API key for the LLM backend",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-base-url",
|
||||||
|
type=str,
|
||||||
|
default=os.environ.get("OPENAI_BASE_URL"),
|
||||||
|
help="Base URL for the LLM backend",
|
||||||
|
)
|
||||||
|
|
||||||
|
# STT Configuration Arguments
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-model",
|
||||||
|
type=str,
|
||||||
|
default=os.environ.get("WHISPER_MODEL", "turbo"),
|
||||||
|
help="The Whisper model to use for STT",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pipeline Execution Argument
|
||||||
|
parser.add_argument(
|
||||||
|
"--run-pipeline",
|
||||||
|
action="store_true",
|
||||||
|
help="Start the main orchestration pipeline (TUI + STT + LLM)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
rag_manager = RAGManager()
|
llm_config = {
|
||||||
|
"backend": args.llm_backend,
|
||||||
|
"model": args.llm_model,
|
||||||
|
"api_key": args.llm_api_key,
|
||||||
|
"base_url": args.llm_base_url,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Remove None values to allow defaults to take over if not provided
|
||||||
|
llm_config = {k: v for k, v in llm_config.items() if v is not None}
|
||||||
|
|
||||||
|
if args.run_pipeline:
|
||||||
|
async def run_pipeline():
    """Create the orchestrator on the running event loop and run it.

    Reads ``llm_config`` and ``args.whisper_model`` from the enclosing
    scope. If a ``KeyboardInterrupt`` is raised while the orchestrator is
    running, ``orchestrator.stop()`` is called and the interrupt is
    swallowed.
    """
    # This coroutine only ever executes under asyncio.run(), so a loop is
    # already running; get_running_loop() returns exactly that loop and
    # never creates one implicitly.
    loop = asyncio.get_running_loop()
    orchestrator = PipelineOrchestrator(
        loop,
        llm_config=llm_config,
        whisper_model=args.whisper_model,
    )
    try:
        await orchestrator.run()
    except KeyboardInterrupt:
        orchestrator.stop()
|
||||||
|
|
||||||
|
asyncio.run(run_pipeline())
|
||||||
|
return
|
||||||
|
|
||||||
|
rag_manager = RAGManager(llm_config=llm_config)
|
||||||
|
|
||||||
if args.ingest_pdf:
|
if args.ingest_pdf:
|
||||||
print(f"Ingesting PDF: {args.ingest_pdf}...")
|
print(f"Ingesting PDF: {args.ingest_pdf}...")
|
||||||
@@ -40,8 +109,8 @@ def main():
|
|||||||
rag_manager.ingest_directory(args.ingest_dir)
|
rag_manager.ingest_directory(args.ingest_dir)
|
||||||
print("Directory ingestion complete.")
|
print("Directory ingestion complete.")
|
||||||
|
|
||||||
if not any([args.ingest_pdf, args.ingest_file, args.ingest_dir]):
|
if not any([args.ingest_pdf, args.ingest_file, args.ingest_dir, args.run_pipeline]):
|
||||||
print("Hello from dnd-helpers!")
|
print("Hello from dnd-helpers! Use --help to see available commands.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -9,3 +9,5 @@ python-dotenv
|
|||||||
llama-index
|
llama-index
|
||||||
chromadb
|
chromadb
|
||||||
pdfplumber
|
pdfplumber
|
||||||
|
llama-index-embeddings-huggingface
|
||||||
|
llama-index-vector_stores-chroma
|
||||||
|
|||||||
+32
-6
@@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from posix import system
|
from posix import system
|
||||||
from this import s
|
from this import s
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
@@ -23,6 +24,7 @@ class LLMProcessor:
|
|||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
|
backend: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the LLMProcessor.
|
Initializes the LLMProcessor.
|
||||||
@@ -30,14 +32,16 @@ class LLMProcessor:
|
|||||||
:param api_key: OpenAI API key. If None, it looks for OPENAI_API_KEY in environment variables.
|
:param api_key: OpenAI API key. If None, it looks for OPENAI_API_KEY in environment variables.
|
||||||
:param base_url: OpenAI-compatible base URL (e.g., for vLLM).
|
:param base_url: OpenAI-compatible base URL (e.g., for vLLM).
|
||||||
:param model: The model to use for processing. If None, it looks for LLM_MODEL in environment variables.
|
:param model: The model to use for processing. If None, it looks for LLM_MODEL in environment variables.
|
||||||
|
:param backend: The LLM backend to use (openai, ollama, or vllm).
|
||||||
"""
|
"""
|
||||||
backend = os.environ.get("LLM_BACKEND", "openai").lower()
|
# Use provided backend or fallback to environment variable
|
||||||
|
backend_env = backend or os.environ.get("LLM_BACKEND", "openai").lower()
|
||||||
|
|
||||||
if backend == "ollama":
|
if backend_env == "ollama":
|
||||||
# Ollama's OpenAI-compatible API
|
# Ollama's OpenAI-compatible API
|
||||||
final_base_url = base_url or "http://localhost:11434/v1"
|
final_base_url = base_url or "http://localhost:11434/v1"
|
||||||
final_api_key = api_key or "ollama"
|
final_api_key = api_key or "ollama"
|
||||||
elif backend == "vllm":
|
elif backend_env == "vllm":
|
||||||
# Remote vLLM server
|
# Remote vLLM server
|
||||||
final_base_url = base_url or os.environ.get("OPENAI_BASE_URL")
|
final_base_url = base_url or os.environ.get("OPENAI_BASE_URL")
|
||||||
final_api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
final_api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||||
@@ -45,18 +49,19 @@ class LLMProcessor:
|
|||||||
final_base_url = base_url or os.environ.get("OPENAI_BASE_URL")
|
final_base_url = base_url or os.environ.get("OPENAI_BASE_URL")
|
||||||
final_api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
final_api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
logger.info(f"Using LLM backend: {backend_env}")
|
||||||
try:
|
try:
|
||||||
self.client = OpenAI(
|
self.client = OpenAI(
|
||||||
api_key=final_api_key,
|
api_key=final_api_key,
|
||||||
base_url=final_base_url,
|
base_url=final_base_url,
|
||||||
)
|
)
|
||||||
# Simple connectivity check for local backends
|
# Simple connectivity check for local backends
|
||||||
if backend == "ollama":
|
if backend_env == "ollama":
|
||||||
# We can't easily check connectivity without making a call,
|
# We can't easily check connectivity without making a call,
|
||||||
# but we can ensure the client is initialized.
|
# but we can ensure the client is initialized.
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error initializing LLM client for backend {backend}: {e}")
|
logger.error(f"Error initializing LLM client for backend {backend_env}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
self.model = model or os.environ.get("LLM_MODEL", "gpt-4o")
|
self.model = model or os.environ.get("LLM_MODEL", "gpt-4o")
|
||||||
@@ -96,15 +101,36 @@ class LLMProcessor:
|
|||||||
|
|
||||||
messages.append({"role": "user", "content": user_prompt})
|
messages.append({"role": "user", "content": user_prompt})
|
||||||
|
|
||||||
|
# Debugging: Dump inputs
|
||||||
|
logger.debug("--- LLM CALL START ---")
|
||||||
|
logger.debug(f"Model: {self.model}")
|
||||||
|
logger.debug(f"Messages: {messages}")
|
||||||
|
if response_format:
|
||||||
|
logger.debug(f"Response Format: {response_format}")
|
||||||
|
logger.debug("--- LLM CALL END ---")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
start_time = time.perf_counter()
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
extra_body={"enable_thinking": False},
|
extra_body={
|
||||||
|
"chat_template_kwargs": {
|
||||||
|
"enable_thinking": False
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
elapsed_time = time.perf_counter() - start_time
|
||||||
|
logger.info(f"LLM request completed in {elapsed_time:.2f}s")
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
|
|
||||||
|
# Debugging: Dump outputs
|
||||||
|
logger.debug("--- LLM RESPONSE START ---")
|
||||||
|
logger.debug(f"Content: {content}")
|
||||||
|
logger.debug("--- LLM RESPONSE END ---")
|
||||||
|
|
||||||
return self._strip_markdown_code_blocks(content)
|
return self._strip_markdown_code_blocks(content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LLM Error: {e}")
|
logger.error(f"LLM Error: {e}")
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ CONSTRAINTS:
|
|||||||
- OUTPUT ONLY VALID JSON.
|
- OUTPUT ONLY VALID JSON.
|
||||||
- If no relevant information is found, return empty lists for all keys.
|
- If no relevant information is found, return empty lists for all keys.
|
||||||
- If a character name is not specified (e.g., "Your character"), use "Player Character".
|
- If a character name is not specified (e.g., "Your character"), use "Player Character".
|
||||||
|
- Do not repeat lore if it is already known; only provide new or updated facts.
|
||||||
|
|
||||||
Strict Output Format:
|
Strict Output Format:
|
||||||
Return a JSON object with exactly these keys:
|
Return a JSON object with exactly these keys:
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from src.ui.tui import ConfirmationApp
|
|||||||
|
|
||||||
# Configure logging to write to a file instead of stdout
|
# Configure logging to write to a file instead of stdout
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.DEBUG,
|
||||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||||
handlers=[
|
handlers=[
|
||||||
logging.FileHandler("pipeline.log"),
|
logging.FileHandler("pipeline.log"),
|
||||||
@@ -39,14 +39,20 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class PipelineOrchestrator:
|
class PipelineOrchestrator:
|
||||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
def __init__(
|
||||||
|
self,
|
||||||
|
loop: asyncio.AbstractEventLoop,
|
||||||
|
llm_config: Optional[dict] = None,
|
||||||
|
whisper_model: str = "base",
|
||||||
|
):
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
|
self.llm_config = llm_config or {}
|
||||||
|
|
||||||
# Modules
|
# Modules
|
||||||
self.listener = AudioListener(loop=self.loop)
|
self.listener = AudioListener(loop=self.loop)
|
||||||
self.transcriber = Transcriber(model_size="base")
|
self.transcriber = Transcriber(model_size=whisper_model, device="cuda")
|
||||||
self.processor = LLMProcessor()
|
self.processor = LLMProcessor(**self.llm_config)
|
||||||
self.rag_manager = RAGManager()
|
self.rag_manager = RAGManager(llm_config=self.llm_config)
|
||||||
|
|
||||||
# Queues
|
# Queues
|
||||||
self.stt_to_clean_queue = asyncio.Queue()
|
self.stt_to_clean_queue = asyncio.Queue()
|
||||||
@@ -56,6 +62,9 @@ class PipelineOrchestrator:
|
|||||||
self.log_queue = asyncio.Queue()
|
self.log_queue = asyncio.Queue()
|
||||||
self.persistence_queue = asyncio.Queue()
|
self.persistence_queue = asyncio.Queue()
|
||||||
|
|
||||||
|
# Synchronization
|
||||||
|
self.transcription_event = asyncio.Event()
|
||||||
|
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
|
||||||
# Conversation history for context
|
# Conversation history for context
|
||||||
@@ -83,11 +92,12 @@ class PipelineOrchestrator:
|
|||||||
|
|
||||||
return f"Conversation History:\n{context_text}\n\n"
|
return f"Conversation History:\n{context_text}\n\n"
|
||||||
|
|
||||||
async def stt_worker(self):
|
async def stt_collector_worker(self):
|
||||||
"""
|
"""
|
||||||
Worker that handles STT: Audio -> Text.
|
Worker that handles STT Collection: Audio -> Buffer.
|
||||||
|
This task is highly responsive and only manages the buffer.
|
||||||
"""
|
"""
|
||||||
logger.info("STT Worker started.")
|
logger.info("STT Collector Worker started.")
|
||||||
while self.is_running:
|
while self.is_running:
|
||||||
try:
|
try:
|
||||||
# Get audio chunk from listener
|
# Get audio chunk from listener
|
||||||
@@ -104,33 +114,68 @@ class PipelineOrchestrator:
|
|||||||
):
|
):
|
||||||
self.audio_buffer.pop(0)
|
self.audio_buffer.pop(0)
|
||||||
|
|
||||||
# Concatenate buffer for transcription
|
# Signal the transcription worker that new data is available
|
||||||
full_audio = np.concatenate(self.audio_buffer)
|
self.transcription_event.set()
|
||||||
|
|
||||||
# Transcribe (WhisperX now returns a list of (speaker, text, start, end))
|
except Exception as e:
|
||||||
|
logger.error(f"STT Collector Worker error: {e}")
|
||||||
|
|
||||||
|
# Small sleep to prevent tight loop
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
async def stt_transcription_worker(self):
    """
    Worker that handles STT Transcription: Buffer -> Text.

    Waits on ``self.transcription_event``, copies ``self.audio_buffer``,
    transcribes the copy in a worker thread, and forwards every segment
    whose start time is at or after ``self.last_processed_end_time`` to
    ``self.log_queue`` (as ``"[speaker] text"``) and to
    ``self.stt_to_clean_queue`` (as ``(speaker, text)``). Loops while
    ``self.is_running`` is true.
    """
    logger.info("STT Transcription Worker started.")
    while self.is_running:
        try:
            # Block until the collector sets the event, then clear it so the
            # next wait() blocks until the collector sets it again.
            await self.transcription_event.wait()
            self.transcription_event.clear()

            # 1. Take a snapshot of the current buffer so chunks appended by
            #    the collector during transcription do not affect this pass.
            buffer_snapshot = list(self.audio_buffer)
            if not buffer_snapshot:
                continue

            # 2. Concatenate and transcribe in a separate thread.
            results = await asyncio.to_thread(
                self._transcribe_buffer_snapshot, buffer_snapshot
            )

            # 3. Keep only segments starting at/after the last processed end
            #    time. The filter is evaluated once, before the loop below
            #    advances last_processed_end_time.
            new_segments = [
                res for res in results if res[2] >= self.last_processed_end_time
            ]

            for speaker, text, start, end in new_segments:
                logger.info(f"STT Raw Transcription: [{speaker}] {text}")
                # Push raw transcription to log queue for UI visibility
                await self.log_queue.put(f"[{speaker}] {text}")
                await self.stt_to_clean_queue.put((speaker, text))
                self.last_processed_end_time = max(
                    self.last_processed_end_time, end
                )

        except Exception as e:
            # logger.exception keeps the traceback; the message text is unchanged.
            logger.exception(f"STT Transcription Worker error: {e}")

        # Small sleep to prevent tight loop
        await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
def _transcribe_buffer_snapshot(self, buffer_snapshot):
|
||||||
|
"""
|
||||||
|
Helper method to be run in a thread.
|
||||||
|
Concatenates the buffer snapshot and transcribes it.
|
||||||
|
"""
|
||||||
|
full_audio = np.concatenate(buffer_snapshot)
|
||||||
|
return self.transcriber.transcribe(full_audio)
|
||||||
|
|
||||||
async def clean_worker(self):
|
async def clean_worker(self):
|
||||||
"""
|
"""
|
||||||
Worker that handles Text Cleaning: Raw STT -> Filtered Text.
|
Worker that handles Text Cleaning: Raw STT -> Filtered Text.
|
||||||
@@ -203,6 +248,7 @@ class PipelineOrchestrator:
|
|||||||
|
|
||||||
while self.is_running:
|
while self.is_running:
|
||||||
try:
|
try:
|
||||||
|
logger.info("LLM Worker: Waiting for input...")
|
||||||
speaker, text = await internal_queue.get()
|
speaker, text = await internal_queue.get()
|
||||||
logger.info(f"LLM Worker: Processing text from {speaker}: {text}")
|
logger.info(f"LLM Worker: Processing text from {speaker}: {text}")
|
||||||
|
|
||||||
@@ -212,6 +258,9 @@ class PipelineOrchestrator:
|
|||||||
# Log the text sent to the LLM for UI affordance
|
# Log the text sent to the LLM for UI affordance
|
||||||
await self.log_queue.put(f"[{speaker}] {text}")
|
await self.log_queue.put(f"[{speaker}] {text}")
|
||||||
|
|
||||||
|
# Log the filtered message being sent to the LLM
|
||||||
|
logger.info(f"LLM Worker: Sending filtered message to LLM: {text}")
|
||||||
|
|
||||||
# Structured extraction using the processor
|
# Structured extraction using the processor
|
||||||
extraction_result = await asyncio.to_thread(
|
extraction_result = await asyncio.to_thread(
|
||||||
self.processor.extract_structured_data,
|
self.processor.extract_structured_data,
|
||||||
@@ -300,7 +349,8 @@ class PipelineOrchestrator:
|
|||||||
|
|
||||||
# Start workers as background tasks
|
# Start workers as background tasks
|
||||||
tasks = [
|
tasks = [
|
||||||
asyncio.create_task(self.stt_worker()),
|
asyncio.create_task(self.stt_collector_worker()),
|
||||||
|
asyncio.create_task(self.stt_transcription_worker()),
|
||||||
asyncio.create_task(self.clean_worker()),
|
asyncio.create_task(self.clean_worker()),
|
||||||
asyncio.create_task(self.llm_worker()),
|
asyncio.create_task(self.llm_worker()),
|
||||||
asyncio.create_task(self.persistence_worker()),
|
asyncio.create_task(self.persistence_worker()),
|
||||||
@@ -328,3 +378,16 @@ class PipelineOrchestrator:
|
|||||||
Stops.
|
Stops.
|
||||||
"""
|
"""
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import asyncio

    async def main():
        """Run a PipelineOrchestrator with default settings.

        If a ``KeyboardInterrupt`` is raised while the orchestrator is
        running, ``orchestrator.stop()`` is called and the interrupt is
        swallowed.
        """
        # Called only via asyncio.run(), so a loop is already running;
        # get_running_loop() returns it without creating one implicitly.
        loop = asyncio.get_running_loop()
        orchestrator = PipelineOrchestrator(loop)
        try:
            await orchestrator.run()
        except KeyboardInterrupt:
            orchestrator.stop()

    asyncio.run(main())
|
||||||
|
|||||||
+3
-2
@@ -12,8 +12,9 @@ from src.llm.processor import LLMProcessor
|
|||||||
|
|
||||||
|
|
||||||
class RAGManager:
|
class RAGManager:
|
||||||
def __init__(self, persist_dir: str = "data/rag_index"):
|
def __init__(self, persist_dir: str = "data/rag_index", llm_config: Optional[dict] = None):
|
||||||
self.persist_dir = persist_dir
|
self.persist_dir = persist_dir
|
||||||
|
self.llm_config = llm_config or {}
|
||||||
self.db = chromadb.PersistentClient(path=self.persist_dir)
|
self.db = chromadb.PersistentClient(path=self.persist_dir)
|
||||||
self.collection_name = "phb_collection"
|
self.collection_name = "phb_collection"
|
||||||
|
|
||||||
@@ -110,7 +111,7 @@ class RAGManager:
|
|||||||
if not nodes:
|
if not nodes:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
processor = LLMProcessor()
|
processor = LLMProcessor(**self.llm_config)
|
||||||
|
|
||||||
# Construct the context from retrieved nodes
|
# Construct the context from retrieved nodes
|
||||||
context_text = "\n\n".join(
|
context_text = "\n\n".join(
|
||||||
|
|||||||
+9
-36
@@ -1,9 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import whisperx
|
import whisperx
|
||||||
from whisperx.diarize import DiarizationPipeline
|
|
||||||
|
|
||||||
# Do not call basicConfig here, as it's called in the orchestrator
|
# Do not call basicConfig here, as it's called in the orchestrator
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -11,14 +9,14 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class Transcriber:
|
class Transcriber:
|
||||||
"""
|
"""
|
||||||
Converts audio chunks (numpy arrays) into text and identifies speakers using WhisperX.
|
Converts audio chunks (numpy arrays) into text using WhisperX.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_size="base", device="cpu", compute_type="int8", language="en"
|
self, model_size="base", device="cpu", compute_type="int8", language="en"
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the WhisperX model and diarization pipeline.
|
Initializes the WhisperX model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
|
model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
|
||||||
@@ -39,28 +37,20 @@ class Transcriber:
|
|||||||
model_size, device=device, compute_type=compute_type
|
model_size, device=device, compute_type=compute_type
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load alignment model (required for accurate speaker assignment)
|
logger.info("WhisperX model loaded successfully.")
|
||||||
# model_dir=None allows automatic model selection based on the language
|
|
||||||
self.align_model, self.align_metadata = whisperx.load_align_model(
|
|
||||||
device=device, model_dir=None, language_code=self.language
|
|
||||||
)
|
|
||||||
|
|
||||||
self.diarize_model = DiarizationPipeline()
|
|
||||||
|
|
||||||
logger.info("WhisperX and Diarization models loaded successfully.")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load WhisperX models: {e}")
|
logger.error(f"Failed to load WhisperX models: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def transcribe(self, audio_chunk):
|
def transcribe(self, audio_chunk):
|
||||||
"""
|
"""
|
||||||
Transcribes an audio chunk and performs speaker diarization.
|
Transcribes an audio chunk.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio_chunk (np.ndarray): The audio data as a numpy array.
|
audio_chunk (np.ndarray): The audio data as a numpy array.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: A list of tuples (speaker_id, text).
|
list: A list of tuples (speaker_id, text, start, end).
|
||||||
"""
|
"""
|
||||||
if audio_chunk is None:
|
if audio_chunk is None:
|
||||||
return []
|
return []
|
||||||
@@ -73,35 +63,18 @@ class Transcriber:
|
|||||||
# batch_size is set to 16 for efficiency; can be adjusted based on VRAM
|
# batch_size is set to 16 for efficiency; can be adjusted based on VRAM
|
||||||
result = self.model.transcribe(audio, batch_size=16)
|
result = self.model.transcribe(audio, batch_size=16)
|
||||||
|
|
||||||
# 2. Perform alignment
|
# Extract ("Unknown", text, start, end) tuples from the transcription result
|
||||||
# Alignment is necessary for the assign_words_to_speakers step
|
|
||||||
result_a = whisperx.align(
|
|
||||||
result["segments"],
|
|
||||||
self.align_model,
|
|
||||||
self.align_metadata,
|
|
||||||
audio,
|
|
||||||
self.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Perform diarization
|
|
||||||
diarize_segments = self.diarize_model(audio)
|
|
||||||
|
|
||||||
# 4. Align transcription segments with speakers
|
|
||||||
result_final = whisperx.assign_word_speakers(diarize_segments, result_a)
|
|
||||||
|
|
||||||
# Extract (speaker_id, text, start, end) tuples from the final result
|
|
||||||
output = []
|
output = []
|
||||||
for segment in result_final.get("segments", []):
|
for segment in result.get("segments", []):
|
||||||
speaker = segment.get("speaker", "Unknown")
|
|
||||||
text = segment.get("text", "").strip()
|
text = segment.get("text", "").strip()
|
||||||
start = segment.get("start", 0.0)
|
start = segment.get("start", 0.0)
|
||||||
end = segment.get("end", 0.0)
|
end = segment.get("end", 0.0)
|
||||||
if text:
|
if text:
|
||||||
output.append((speaker, text, start, end))
|
output.append(("Unknown", text, start, end))
|
||||||
|
|
||||||
return output
|
return output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Transcription/Diarization error: {e}")
|
logger.error(f"Transcription error: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|||||||
+32
-28
@@ -16,20 +16,26 @@ from textual.widgets import (
|
|||||||
Static,
|
Static,
|
||||||
)
|
)
|
||||||
|
|
||||||
from src.llm.models import CharacterStateUpdate, ExtractionResult, LoreUpdate
|
from src.llm.models import CharacterStateUpdate, ContextUpdate, ExtractionResult, LoreUpdate
|
||||||
from src.persistence.characters import update_character_state
|
from src.persistence.characters import update_character_state
|
||||||
from src.persistence.lore import update_lore
|
from src.persistence.lore import update_lore
|
||||||
|
|
||||||
|
|
||||||
class EditModal(ModalScreen):
|
class EditModal(ModalScreen):
|
||||||
def __init__(self, initial_text: str, on_save: callable):
|
def __init__(self, initial_text: str, initial_type: str, initial_target: str, on_save: callable):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.initial_text = initial_text
|
self.initial_text = initial_text
|
||||||
|
self.initial_type = initial_type
|
||||||
|
self.initial_target = initial_target
|
||||||
self.on_save = on_save
|
self.on_save = on_save
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
def compose(self) -> ComposeResult:
|
||||||
with Vertical(id="modal-container"):
|
with Vertical(id="modal-container"):
|
||||||
yield Label("Edit Fact Content:")
|
yield Label("Type:")
|
||||||
|
yield Input(value=self.initial_type, id="edit-type")
|
||||||
|
yield Label("Target:")
|
||||||
|
yield Input(value=self.initial_target, id="edit-target")
|
||||||
|
yield Label("Content:")
|
||||||
yield Input(value=self.initial_text, id="edit-input")
|
yield Input(value=self.initial_text, id="edit-input")
|
||||||
with Horizontal(id="modal-actions"):
|
with Horizontal(id="modal-actions"):
|
||||||
yield Button("Save", id="btn-save")
|
yield Button("Save", id="btn-save")
|
||||||
@@ -38,7 +44,9 @@ class EditModal(ModalScreen):
|
|||||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
if event.button.id == "btn-save":
|
if event.button.id == "btn-save":
|
||||||
edit_input = self.query_one("#edit-input", Input)
|
edit_input = self.query_one("#edit-input", Input)
|
||||||
self.on_save(edit_input.value)
|
type_input = self.query_one("#edit-type", Input)
|
||||||
|
target_input = self.query_one("#edit-target", Input)
|
||||||
|
self.on_save(edit_input.value, type_input.value, target_input.value)
|
||||||
self.dismiss()
|
self.dismiss()
|
||||||
elif event.button.id == "btn-cancel":
|
elif event.button.id == "btn-cancel":
|
||||||
self.dismiss()
|
self.dismiss()
|
||||||
@@ -68,32 +76,27 @@ class ConfirmationApp(App):
|
|||||||
}
|
}
|
||||||
|
|
||||||
#pending-facts-table {
|
#pending-facts-table {
|
||||||
height: 40%;
|
height: 30%;
|
||||||
border: solid white;
|
border: solid white;
|
||||||
}
|
}
|
||||||
|
|
||||||
#llm-input-container {
|
#llm-input-container {
|
||||||
height: 10%;
|
height: 10%;
|
||||||
border: solid white;
|
border: solid white;
|
||||||
padding: 1;
|
padding: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#context-pane {
|
#context-pane {
|
||||||
height: 50%;
|
height: 60%;
|
||||||
border: solid white;
|
border: solid white;
|
||||||
}
|
}
|
||||||
|
|
||||||
#log-pane {
|
#log-pane {
|
||||||
height: 30%;
|
height: 100%;
|
||||||
border: solid white;
|
border: solid white;
|
||||||
background: #111;
|
background: #111;
|
||||||
}
|
}
|
||||||
|
|
||||||
#log-footer {
|
|
||||||
height: 70%;
|
|
||||||
border: solid white;
|
|
||||||
}
|
|
||||||
|
|
||||||
#modal-container {
|
#modal-container {
|
||||||
width: 60%;
|
width: 60%;
|
||||||
height: auto;
|
height: auto;
|
||||||
@@ -109,7 +112,7 @@ class ConfirmationApp(App):
|
|||||||
align: right middle;
|
align: right middle;
|
||||||
}
|
}
|
||||||
|
|
||||||
#edit-input {
|
#edit-input, #edit-type, #edit-target {
|
||||||
margin: 1 0;
|
margin: 1 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,18 +158,11 @@ class ConfirmationApp(App):
|
|||||||
Horizontal(
|
Horizontal(
|
||||||
Vertical(
|
Vertical(
|
||||||
DataTable(id="pending-facts-table"),
|
DataTable(id="pending-facts-table"),
|
||||||
Vertical(
|
Input(placeholder="Message LLM...", id="llm-input"),
|
||||||
Input(placeholder="Message LLM...", id="llm-input"),
|
|
||||||
id="llm-input-container",
|
|
||||||
),
|
|
||||||
ListView(id="context-pane"),
|
ListView(id="context-pane"),
|
||||||
id="left-pane",
|
id="left-pane",
|
||||||
),
|
),
|
||||||
Vertical(
|
ListView(id="log-pane"),
|
||||||
ListView(id="log-pane"),
|
|
||||||
Static("LATEST LLM INPUTS", id="log-footer"),
|
|
||||||
id="right-pane",
|
|
||||||
),
|
|
||||||
id="content-wrapper",
|
id="content-wrapper",
|
||||||
),
|
),
|
||||||
id="main-container",
|
id="main-container",
|
||||||
@@ -199,7 +195,7 @@ class ConfirmationApp(App):
|
|||||||
table = self.query_one("#pending-facts-table", DataTable)
|
table = self.query_one("#pending-facts-table", DataTable)
|
||||||
if isinstance(update, LoreUpdate):
|
if isinstance(update, LoreUpdate):
|
||||||
table.add_row(
|
table.add_row(
|
||||||
"Lore", update.entity_name or "General", update.content, key=str(index)
|
update.category, update.entity_name or "General", update.content, key=str(index)
|
||||||
)
|
)
|
||||||
elif isinstance(update, CharacterStateUpdate):
|
elif isinstance(update, CharacterStateUpdate):
|
||||||
change_text = f"HP: {update.hp_change or 0}"
|
change_text = f"HP: {update.hp_change or 0}"
|
||||||
@@ -218,8 +214,6 @@ class ConfirmationApp(App):
|
|||||||
elif isinstance(update, ContextUpdate):
|
elif isinstance(update, ContextUpdate):
|
||||||
display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
|
display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
|
||||||
context_list = self.query_one("#context-pane", ListView)
|
context_list = self.query_one("#context-pane", ListView)
|
||||||
# ListView.insert takes an *iterable* of ListItems; passing a
|
|
||||||
# bare ListItem raises TypeError because ListItem is not iterable.
|
|
||||||
# Insert at the top to show most recent first.
|
# Insert at the top to show most recent first.
|
||||||
await context_list.insert(0, [ListItem(Static(display_text))])
|
await context_list.insert(0, [ListItem(Static(display_text))])
|
||||||
if hasattr(self.llm_to_ui_queue, "task_done"):
|
if hasattr(self.llm_to_ui_queue, "task_done"):
|
||||||
@@ -294,24 +288,34 @@ class ConfirmationApp(App):
|
|||||||
|
|
||||||
update = self.pending_updates[row_index]
|
update = self.pending_updates[row_index]
|
||||||
initial_text = ""
|
initial_text = ""
|
||||||
|
initial_type = ""
|
||||||
|
initial_target = ""
|
||||||
|
|
||||||
if isinstance(update, LoreUpdate):
|
if isinstance(update, LoreUpdate):
|
||||||
initial_text = update.content
|
initial_text = update.content
|
||||||
|
initial_type = update.category
|
||||||
|
initial_target = update.entity_name or ""
|
||||||
elif isinstance(update, CharacterStateUpdate):
|
elif isinstance(update, CharacterStateUpdate):
|
||||||
initial_text = str(update.hp_change or 0)
|
initial_text = str(update.hp_change or 0)
|
||||||
|
initial_type = "Char"
|
||||||
|
initial_target = update.character_name
|
||||||
|
|
||||||
def save_callback(new_text: str):
|
def save_callback(new_text: str, new_type: str, new_target: str):
|
||||||
if isinstance(update, LoreUpdate):
|
if isinstance(update, LoreUpdate):
|
||||||
update.content = new_text
|
update.content = new_text
|
||||||
|
update.category = new_type
|
||||||
|
update.entity_name = new_target if new_target else None
|
||||||
elif isinstance(update, CharacterStateUpdate):
|
elif isinstance(update, CharacterStateUpdate):
|
||||||
try:
|
try:
|
||||||
update.hp_change = int(new_text)
|
update.hp_change = int(new_text)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
update.character_name = new_target
|
||||||
|
|
||||||
# Update the table
|
# Update the table
|
||||||
self.refresh_table()
|
self.refresh_table()
|
||||||
|
|
||||||
self.push_screen(EditModal(initial_text, save_callback))
|
self.push_screen(EditModal(initial_text, initial_type, initial_target, save_callback))
|
||||||
|
|
||||||
def remove_update(self, index: int) -> None:
|
def remove_update(self, index: int) -> None:
|
||||||
del self.pending_updates[index]
|
del self.pending_updates[index]
|
||||||
|
|||||||
Reference in New Issue
Block a user