"""
|
|
Discord Voice Bot - Simple GLaDOS Voice Version
|
|
Uses Wyoming Whisper for STT, Ollama for LLM, HTTP TTS for GLaDOS voice.
|
|
Works WITHOUT discord.sinks (manual audio capture)
|
|
"""
|
|
|
|
import logging

# Configure logging before the remaining imports so that the optional-import
# guards below can emit warnings through the module logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

import asyncio
import io
import os
import sys
import tempfile
import wave
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import requests
import yaml

import discord
from discord.ext import commands
import json
# NOTE(review): io, wave, ThreadPoolExecutor and json appear unused in this
# file -- confirm before removing.
|
# Import Wyoming protocol. The library is optional: when it is missing,
# WYOMING_AVAILABLE gates WyomingWhisper.transcribe() to return None.
try:
    from wyoming.client import AsyncTcpClient
    from wyoming.audio import AudioChunk, AudioStart, AudioStop
    from wyoming.asr import Transcribe, Transcript
    WYOMING_AVAILABLE = True
except ImportError:
    logger.warning("Wyoming library not available")
    WYOMING_AVAILABLE = False
|
# Optional: Import GLaDOS ASR (Windows path).
# NOTE(review): hard-coded absolute Windows path -- consider making this
# configurable via config.yaml.
sys.path.insert(0, r'C:\glados\src')
try:
    from glados.ASR import get_audio_transcriber
    GLADOS_ASR_AVAILABLE = True
    logger.info("GLaDOS ASR module found")
except ImportError:
    GLADOS_ASR_AVAILABLE = False
    logger.warning("GLaDOS ASR not available")
|
# Initialize GLaDOS ASR if available (fallback).
# parakeet_asr stays None when the module is missing or the model fails to
# load; callers must check for None before use.
parakeet_asr = None
if GLADOS_ASR_AVAILABLE:
    try:
        logger.info("Loading GLaDOS Parakeet ASR model...")
        # "tdt" presumably selects the Parakeet TDT engine -- TODO confirm
        # against the glados.ASR API.
        parakeet_asr = get_audio_transcriber(engine_type="tdt")
        logger.info("Parakeet ASR loaded")
    except Exception as e:
        logger.error(f"Failed to load Parakeet ASR: {e}")
|
class WyomingWhisper:
    """Speech-to-text using Wyoming Whisper."""

    def __init__(self, host="localhost", port=10300):
        self.host = host
        self.port = port

    async def transcribe(self, audio_bytes):
        """Stream raw PCM to the Wyoming server and return the transcript.

        *audio_bytes* is treated as 16 kHz / 16-bit / mono PCM. Returns the
        transcript text, or None when Wyoming is unavailable, the connection
        fails, or the server closes the stream without a transcript.
        """
        if not WYOMING_AVAILABLE:
            return None

        # Audio format and framing used for the outgoing stream.
        rate, width, channels = 16000, 2, 1
        step = 4096

        try:
            async with AsyncTcpClient(self.host, self.port) as client:
                # Announce intent, then frame the audio as start/chunks/stop.
                await client.write_event(Transcribe().event())
                await client.write_event(
                    AudioStart(rate=rate, width=width, channels=channels).event()
                )

                offset = 0
                total = len(audio_bytes)
                while offset < total:
                    piece = audio_bytes[offset:offset + step]
                    await client.write_event(
                        AudioChunk(
                            audio=piece, rate=rate, width=width, channels=channels
                        ).event()
                    )
                    offset += step

                await client.write_event(AudioStop().event())

                # Drain events until a transcript arrives or the stream ends.
                while True:
                    event = await client.read_event()
                    if event is None:
                        return None
                    if Transcript.is_type(event.type):
                        return Transcript.from_event(event).text
        except Exception as e:
            logger.error(f"Wyoming Whisper error: {e}")
            return None
|
class ParakeetASR:
    """Speech-to-text using GLaDOS Parakeet ASR (fallback).

    Assumes *audio_bytes* is raw 16-bit PCM at 48 kHz, mono -- TODO confirm
    against callers; Discord capture is normally 48 kHz stereo and this path
    does no stereo-to-mono mix-down.
    """

    async def transcribe(self, audio_bytes):
        """Transcribe raw 48 kHz mono int16 PCM bytes; return text or None."""
        if not parakeet_asr:
            return None
        try:
            audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
            # Cap input at 30 seconds of 48 kHz audio to bound latency/memory.
            if len(audio_np) > 48000 * 30:
                audio_np = audio_np[:48000 * 30]
            # Naive decimation 48 kHz -> 16 kHz (keep every 3rd sample).
            ratio = 48000 // 16000
            audio_16k = audio_np[::ratio].astype(np.int16)
            # Normalize int16 to float32 in [-1.0, 1.0]. Bug fix: the raw
            # int16 range was previously passed through unscaled, which is
            # inconsistent with convert_discord_audio_to_parakeet(), the
            # other path feeding the same model.
            audio_float = audio_16k.astype(np.float32) / 32768.0
            text = parakeet_asr.transcribe(audio_float)
            return text.strip() if text else None
        except Exception as e:
            logger.error(f"Parakeet ASR error: {e}")
            return None
|
class HTTPTTS:
    """Text-to-speech via an OpenAI-style HTTP /v1/audio/speech endpoint."""

    def __init__(self, base_url, voice="glados"):
        self.base_url = base_url
        self.voice = voice

    async def synthesize(self, text):
        """Synthesize *text*; return raw audio bytes (WAV/MP3) or None.

        Bug fix: the blocking requests.post previously ran directly inside
        this coroutine, stalling the event loop for up to 30 seconds; it now
        runs in the default executor (same pattern the bot already uses for
        transcription).
        """
        try:
            loop = asyncio.get_event_loop()
            response = await loop.run_in_executor(
                None,
                lambda: requests.post(
                    f"{self.base_url}/v1/audio/speech",
                    json={"input": text, "voice": self.voice},
                    timeout=30,
                ),
            )
            if response.status_code in [200, 201]:
                logger.info(f"Got TTS audio: {len(response.content)} bytes")
                return response.content
            # Previously a non-2xx response fell through silently.
            logger.error(f"TTS HTTP {response.status_code}")
        except Exception as e:
            logger.error(f"TTS error: {e}")
        return None
|
class OllamaClient:
    """Minimal blocking client for the Ollama /api/generate endpoint."""

    def __init__(self, base_url, model):
        self.base_url = base_url
        self.model = model

    def generate(self, user_message):
        """Return the model's reply to *user_message*.

        Any failure (network, timeout, bad JSON) is logged and mapped to a
        stock apology string so callers always get speakable text.
        """
        try:
            endpoint = f"{self.base_url}/api/generate"
            body = {
                "model": self.model,
                "prompt": f"Keep responses concise and conversational. User: {user_message}",
                "stream": False,
            }
            reply = requests.post(endpoint, json=body, timeout=30)
            data = reply.json()
            return data.get('response', '').strip()
        except Exception as e:
            logger.error(f"Ollama error: {e}")
            return "I'm sorry, I couldn't process that."
|
# Load config.yaml that sits next to this script. Expected keys (per the
# reads below): whisper.host/port, tts.http_url, tts.voice (optional),
# ollama.base_url/model, discord.token.
config_path = os.path.join(os.path.dirname(__file__), 'config.yaml')
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Module-level singleton clients shared by the bot and its commands.
# NOTE(review): whisper_stt and parakeet_stt are constructed but never used
# below -- the voice pipeline calls parakeet_asr directly; confirm intent.
whisper_stt = WyomingWhisper(config['whisper']['host'], config['whisper']['port']) if WYOMING_AVAILABLE else None
parakeet_stt = ParakeetASR()
http_tts = HTTPTTS(config['tts']['http_url'], config['tts'].get('voice', 'glados'))
ollama = OllamaClient(config['ollama']['base_url'], config['ollama']['model'])
|
class VoiceBot(commands.Bot):
    """Discord voice bot WITHOUT sinks dependency.

    Captures voice manually (polling the voice client), transcribes with
    Parakeet, replies via Ollama, and speaks through the HTTP TTS service.
    """

    def __init__(self, *args, **kwargs):
        intents = discord.Intents.default()
        intents.message_content = True
        intents.voice_states = True
        super().__init__(command_prefix="!", intents=intents, *args, **kwargs)
        self.voice_client = None          # active discord voice connection, if any
        self.config = config              # module-level YAML config
        self._recording = False           # True while record_audio() is running
        self._audio_buffer = bytearray()  # raw PCM accumulated during recording

    async def on_ready(self):
        """Log identity and usage hints once the gateway connection is up."""
        logger.info(f"Bot ready! {self.user.name} ({self.user.id})")
        logger.info("Use !join to connect to voice channel, !leave to disconnect")

    async def on_message(self, message):
        """Ignore our own messages, then dispatch prefix commands."""
        if message.author == self.user:
            return
        await self.process_commands(message)

    async def join_voice_channel(self, channel):
        """Connect to *channel*, dropping any existing voice connection first."""
        if self.voice_client:
            await self.voice_client.disconnect()
        self.voice_client = await channel.connect()
        logger.info(f"Joined voice channel: {channel.name}")

    def convert_discord_audio_to_parakeet(self, audio_bytes):
        """Convert Discord 48kHz stereo PCM to 16kHz mono float32 for Parakeet.

        Returns a float32 numpy array normalized to [-1.0, 1.0], or None if
        conversion fails (e.g. buffer length not a multiple of a stereo frame).
        """
        try:
            # Discord audio is 48kHz, stereo, 16-bit PCM
            audio_np = np.frombuffer(audio_bytes, dtype=np.int16)

            # Stereo to mono: average left and right channels
            audio_np = audio_np.reshape(-1, 2).mean(axis=1).astype(np.int16)

            # Resample 48kHz to 16kHz by naive decimation (every 3rd sample)
            audio_16k = audio_np[::3]

            # Convert int16 to float32 (normalize to [-1.0, 1.0])
            audio_float = audio_16k.astype(np.float32) / 32768.0

            return audio_float
        except Exception as e:
            logger.error(f"Audio conversion error: {e}")
            return None

    async def record_audio(self, duration=5):
        """Record audio from voice channel for *duration* seconds.

        Returns the captured bytes (possibly empty), or None when not
        connected to a voice channel.

        NOTE(review): stock discord.py's VoiceClient exposes no receive()
        method, so each poll below likely raises AttributeError and is
        swallowed by the debug-level handler, yielding 0 bytes. Confirm the
        installed discord fork actually provides receive().
        """
        if not self.voice_client:
            logger.warning("Not in voice channel")
            return None

        self._recording = True
        self._audio_buffer = bytearray()

        logger.info(f"Recording for {duration} seconds...")
        start_time = asyncio.get_event_loop().time()

        while self._recording and (asyncio.get_event_loop().time() - start_time) < duration:
            try:
                # Poll for an audio packet; short timeout keeps the loop live.
                packet = await asyncio.wait_for(
                    self.voice_client.receive(),
                    timeout=0.1
                )
                if packet and hasattr(packet, 'data'):
                    self._audio_buffer.extend(packet.data)
            except asyncio.TimeoutError:
                continue
            except Exception as e:
                logger.debug(f"Recv error: {e}")
                continue

        self._recording = False
        audio_data = bytes(self._audio_buffer)
        logger.info(f"Recorded {len(audio_data)} bytes")
        return audio_data

    async def process_voice_command(self, ctx):
        """Record, transcribe, get LLM response, and speak.

        Posts progress and per-stage latency messages to *ctx* throughout.
        """
        await ctx.send("🎙️ Listening... (speak now)")

        # Record audio
        start_time = asyncio.get_event_loop().time()
        audio_bytes = await self.record_audio(duration=5)
        record_time = asyncio.get_event_loop().time() - start_time

        # < 1000 bytes is treated as effectively silent / no capture.
        if not audio_bytes or len(audio_bytes) < 1000:
            await ctx.send("❌ No audio captured (too quiet or not in voice channel)")
            return

        await ctx.send(f"📝 Transcribing ({len(audio_bytes)} bytes, {record_time:.1f}s)...")

        # Convert audio format
        audio_float = self.convert_discord_audio_to_parakeet(audio_bytes)
        if audio_float is None:
            await ctx.send("❌ Audio conversion failed")
            return

        # Transcribe with Parakeet
        transcribe_start = asyncio.get_event_loop().time()
        try:
            # Run transcription in thread pool (it's CPU intensive)
            loop = asyncio.get_event_loop()
            text = await loop.run_in_executor(
                None,
                lambda: parakeet_asr.transcribe(audio_float)
            )
            transcribe_time = asyncio.get_event_loop().time() - transcribe_start
        except Exception as e:
            logger.error(f"Transcription error: {e}")
            await ctx.send(f"❌ Transcription failed: {e}")
            return

        if not text or not text.strip():
            await ctx.send("❌ No speech detected")
            return

        await ctx.send(f"👤 You said: \"{text}\" ({transcribe_time:.1f}s)")

        # Get LLM response.
        # NOTE(review): ollama.generate is a blocking HTTP call executed on
        # the event loop -- consider run_in_executor here as well.
        llm_start = asyncio.get_event_loop().time()
        response = ollama.generate(text)
        llm_time = asyncio.get_event_loop().time() - llm_start

        if not response:
            await ctx.send("❌ LLM failed to respond")
            return

        await ctx.send(f"🤖 GLaDOS: \"{response}\" ({llm_time:.1f}s)")

        # Synthesize and speak
        tts_start = asyncio.get_event_loop().time()
        audio = await http_tts.synthesize(response)
        tts_time = asyncio.get_event_loop().time() - tts_start

        if audio:
            await self.play_audio(audio)
            total_time = record_time + transcribe_time + llm_time + tts_time
            await ctx.send(f"⏱️ Total latency: {total_time:.1f}s (rec: {record_time:.1f}, stt: {transcribe_time:.1f}, llm: {llm_time:.1f}, tts: {tts_time:.1f})")
        else:
            await ctx.send("❌ TTS failed")

    async def play_audio(self, audio_bytes):
        """Play *audio_bytes* (WAV or MP3) in the voice channel via FFmpeg.

        Returns True on success, False otherwise. Waits (cooperatively)
        until playback has finished before returning.
        """
        if not self.voice_client:
            logger.warning("Not connected to voice channel")
            return False

        # Sniff the container: WAV files start with the ASCII tag 'RIFF'.
        if audio_bytes[:4] == b'RIFF':
            suffix = '.wav'
        else:
            suffix = '.mp3'

        # Create a temp file for FFmpeg (delete=False: FFmpeg reopens it by path)
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp:
            temp.write(audio_bytes)
            temp_path = temp.name

        try:
            source = discord.FFmpegPCMAudio(temp_path)
            if self.voice_client.is_playing():
                self.voice_client.stop()
            self.voice_client.play(source)

            # Wait for playback to finish
            while self.voice_client.is_playing():
                await asyncio.sleep(0.1)
            return True
        except Exception as e:
            logger.error(f"Error playing audio: {e}")
            return False
        finally:
            # Best-effort temp-file cleanup. Bug fix: this was a bare
            # `except:` which also swallowed KeyboardInterrupt/SystemExit.
            try:
                os.unlink(temp_path)
            except OSError:
                pass
|
# Single module-level bot instance; the command decorators below attach to it.
bot = VoiceBot()
|
@bot.command(name='leave')
async def leave(ctx):
    """Disconnect from the current voice channel (no-op if not connected)."""
    vc = bot.voice_client
    if vc:
        await vc.disconnect()
        bot.voice_client = None
    await ctx.send("Left voice channel.")
|
@bot.command(name='join')
async def join(ctx):
    """Join the voice channel the invoking user is currently in."""
    voice_state = ctx.author.voice
    if not voice_state:
        # Caller must already be in a voice channel for us to follow them.
        await ctx.send("You need to be in a voice channel!")
        return
    target = voice_state.channel
    await bot.join_voice_channel(target)
    await ctx.send(f"Joined {target.name}!")
|
@bot.command(name='test')
async def test(ctx, *, text="Hello! This is a test."):
    """Speak *text* through TTS in the current voice channel."""
    if not bot.voice_client:
        await ctx.send("Not in voice channel! Use !join first.")
        return

    await ctx.send(f"🎙️ Saying: {text}")
    audio = await http_tts.synthesize(text)
    if not audio:
        await ctx.send("TTS error.")
        return
    # Synthesis succeeded; report only if playback itself fails.
    played = await bot.play_audio(audio)
    if not played:
        await ctx.send("Failed to play audio.")
|
@bot.command(name='say')
async def say(ctx, *, text):
    """Speak *text* via TTS; thin alias for !test with required text."""
    # Delegate to the test command's callback with the same context.
    await test(ctx, text=text)
|
@bot.command(name='listen')
async def listen(ctx):
    """Record voice for 5 seconds, transcribe, and respond."""
    # Guard 1: must be connected to voice to capture anything.
    if not bot.voice_client:
        await ctx.send("Not in voice channel! Use !join first.")
        return

    # Guard 2: the Parakeet model must have loaded at startup.
    if not parakeet_asr:
        await ctx.send("❌ Parakeet ASR not available. Check GLaDOS installation.")
        return

    await bot.process_voice_command(ctx)
|
@bot.command(name='ask')
async def ask(ctx, *, question):
    """Ask the LLM something (text only, for now)."""
    await ctx.send("🤔 Thinking...")
    answer = ollama.generate(question)
    if not answer:
        await ctx.send("Failed to get response.")
        return

    await ctx.send(f"💬 {answer}")
    # Also speak the reply when connected to a voice channel.
    if bot.voice_client:
        speech = await http_tts.synthesize(answer)
        if speech:
            await bot.play_audio(speech)
|
async def main():
    """Validate the configured token, then run the bot until shutdown."""
    token = config['discord']['token']
    # A "YOUR_..." token means the sample config was never filled in.
    if token.startswith("YOUR_"):
        logger.error("Configure Discord token in config.yaml!")
        return

    logger.info("Starting Discord bot...")
    await bot.start(token)


if __name__ == '__main__':
    asyncio.run(main())