import gradio as gr
import subprocess, json, os, io, tempfile
from faster_whisper import WhisperModel
from ollama import Client as OllamaClient

# ---- CONFIG ----
LLM_MODEL = "llama3.2:3b"   # or "mistral:7b", "qwen2.5:3b"
WHISPER_SIZE = "small"      # "base", "small", "medium"
USE_SILERO = True           # set False to use Coqui XTTS v2
USE_CONTEXT = False         # False = no conversational memory (see sketch at end of file)

# ---- LLM backend (remote Ollama or local Transformers fallback) ----
USE_REMOTE_OLLAMA = bool(os.getenv("OLLAMA_HOST"))

if USE_REMOTE_OLLAMA:
    ollama = OllamaClient(host=os.getenv("OLLAMA_HOST", "http://127.0.0.1:11434"))
else:
    # Transformers fallback for Spaces (CPU-friendly small instruct model)
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline as hf_pipeline
    HF_CHAT_MODEL = os.getenv("HF_CHAT_MODEL", "google/gemma-2-2b-it")  # small instruct model that runs on CPU
    HF_TOKEN = os.getenv("HF_TOKEN")
    _tok = AutoTokenizer.from_pretrained(HF_CHAT_MODEL, token=HF_TOKEN)
    _mdl = AutoModelForCausalLM.from_pretrained(HF_CHAT_MODEL, token=HF_TOKEN,
                                                torch_dtype="auto", device_map="auto")
    gen = hf_pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=256)

# ---- STT (faster-whisper) ----
# Runs on GPU if available (device="cuda", compute_type="float16"), otherwise CPU with int8
stt_model = WhisperModel(
    WHISPER_SIZE,
    device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu",
    compute_type="float16" if os.environ.get("CUDA_VISIBLE_DEVICES") else "int8",
)


def speech_to_text(audio_path: str) -> str:
    segments, info = stt_model.transcribe(audio_path, beam_size=1, vad_filter=True, language="en")
    return "".join(seg.text for seg in segments).strip()


SYSTEM_PROMPT = """You are a friendly AI voice assistant.
Reply in one short, natural sentence only.
Sound warm and conversational, never formal.
Avoid multi-sentence or paragraph answers."""


def chat_with_llm(history_messages, user_text):
    if USE_REMOTE_OLLAMA:
        # Only system + current user (USE_CONTEXT is off)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_text},
        ]
        resp = ollama.chat(model=LLM_MODEL, messages=messages)
        return resp["message"]["content"]
    else:
        # Only system + current user (USE_CONTEXT is off)
        prompt = f"{SYSTEM_PROMPT}\nUser: {user_text}\nAssistant:"
        out = gen(
            prompt,
            return_full_text=False,
            max_new_tokens=25,
            do_sample=True,
            temperature=0.8,
            repetition_penalty=1.1,
        )[0]["generated_text"]
        return out.split("\n")[0].strip()


# Global singleton so the Silero model is downloaded and loaded only once
_SILERO_TTS = None


def tts_silero(text: str) -> str:
    """
    Return the path to a WAV file synthesized by Silero TTS.
    Uses a cached model instance to avoid re-downloading on each request.
    """
    import torch
    import soundfile as sf

    global _SILERO_TTS
    if _SILERO_TTS is None:
        obj = torch.hub.load(
            repo_or_dir="snakers4/silero-models",
            model="silero_tts",
            language="en",
            speaker="v3_en",
            trust_repo=True,  # avoids the interactive trust prompt
        )
        _SILERO_TTS = obj[0] if isinstance(obj, (list, tuple)) else obj

    model = _SILERO_TTS
    sample_rate = 48000
    speaker = "en_0"
    audio = model.apply_tts(text=text, speaker=speaker, sample_rate=sample_rate)
    out_wav = tempfile.mktemp(suffix=".wav")
    sf.write(out_wav, audio, sample_rate)
    return out_wav


def tts_coqui_xtts(text: str) -> str:
    """
    Return the path to a WAV file synthesized by Coqui XTTS v2 (higher quality; GPU-friendly).
    """
    from TTS.api import TTS

    tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
    out_wav = tempfile.mktemp(suffix=".wav")
    tts.tts_to_file(text=text, file_path=out_wav, speaker="female-en-5", language="en")
    return out_wav


def text_to_speech(text: str) -> str:
    return tts_silero(text) if USE_SILERO else tts_coqui_xtts(text)


# ---- Gradio pipeline ----
def pipeline(audio, history):
    # audio is (sample_rate, np.array) OR a filepath, depending on the Gradio Audio type.
    # Normalize to a temp WAV file before transcription.
    if audio is None:
        return history, None, "Please speak something."

    if isinstance(audio, tuple):
        # (sr, data) -> write a WAV file
        import soundfile as sf
        sr, data = audio
        tmp_in = tempfile.mktemp(suffix=".wav")
        sf.write(tmp_in, data.astype("float32"), sr)
        audio_path = tmp_in
    else:
        audio_path = audio  # already a file path

    user_text = speech_to_text(audio_path)
    if not user_text:
        return history, None, "Didn't catch that—could you repeat?"

    reply = chat_with_llm(history, user_text)

    # If the model prefixes its answer with "Reply:", speak only the text after it
    speak_text = reply
    if "Reply:" in reply:
        speak_text = reply.split("Reply:", 1)[1].strip()

    wav_path = text_to_speech(speak_text)

    updated = (history or []) + [
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": reply},
    ]
    return updated, wav_path, ""


with gr.Blocks(title="Voice Coach") as demo:
    gr.Markdown("## 🎙️ Interactive Voice Chat")
    with gr.Row():
        audio_in = gr.Audio(sources=["microphone"], type="filepath", label="Speak")
        audio_out = gr.Audio(label="Assistant (TTS)", autoplay=True)
    chatbox = gr.Chatbot(type="messages", height=300)
    status = gr.Markdown()
    btn = gr.Button("Send")

    # A finished recording triggers the pipeline via .change, or press "Send" after recording
    audio_in.change(pipeline, inputs=[audio_in, chatbox], outputs=[chatbox, audio_out, status])
    btn.click(pipeline, inputs=[audio_in, chatbox], outputs=[chatbox, audio_out, status])

if __name__ == "__main__":
    demo.launch()