import tempfile
import time
from pathlib import Path
from typing import Optional, Tuple

import spaces
import gradio as gr
import numpy as np
import soundfile as sf
import torch

from dia.model import Dia
# Model selection: display name -> Hugging Face model id
DIA_MODELS = {
    "Dhivehi Dia-1.6B": "alakxender/Dia-1.6B-dhivehi-ep1",
    # "Dhivehi 18k": "alakxender/Dia-1.6B-dhivehi-18k"
}

# Cache of loaded models so each checkpoint is only loaded once per session
dia_models = {}


def load_dia_model(model_id):
    if model_id not in dia_models:
        print(f"Loading model {model_id}")
        dia_models[model_id] = Dia.from_pretrained(model_id)
        print(f"Loaded model {model_id}")
    return dia_models[model_id]


# On ZeroGPU hardware a GPU is attached only to calls wrapped in spaces.GPU,
# which is why the `spaces` package is imported above.
@spaces.GPU
def run_inference(
    text_input: str,
    audio_prompt_input: Optional[Tuple[int, np.ndarray]],
    transcription_input: Optional[str],
    max_new_tokens: int,
    cfg_scale: float,
    temperature: float,
    top_p: float,
    cfg_filter_top_k: int,
    speed_factor: float,
    model_name: str,
):
    model_id = DIA_MODELS[model_name]
    model = load_dia_model(model_id)

    if not text_input or text_input.isspace():
        raise gr.Error("Text input cannot be empty.")

    temp_txt_file_path = None
    temp_audio_prompt_path = None
    # Fallback output: a single silent sample at 44.1 kHz
    output_audio = (44100, np.zeros(1, dtype=np.float32))
    try:
        prompt_path_for_generate = None
        if audio_prompt_input is not None:
            sr, audio_data = audio_prompt_input
            if audio_data is None or audio_data.size == 0 or audio_data.max() == 0:
                gr.Warning("Audio prompt seems empty or silent, ignoring prompt.")
            else:
                # Enforce the prompt length limit once we know the data is non-empty
                duration_sec = len(audio_data) / float(sr) if sr else 0
                if duration_sec > 10.0:
                    raise gr.Error("Audio prompt must be 10 seconds or shorter.")
                with tempfile.NamedTemporaryFile(
                    mode="wb", suffix=".wav", delete=False
                ) as f_audio:
                    temp_audio_prompt_path = f_audio.name
                    # Normalise integer PCM to float32 in [-1.0, 1.0]
                    if np.issubdtype(audio_data.dtype, np.integer):
                        max_val = np.iinfo(audio_data.dtype).max
                        audio_data = audio_data.astype(np.float32) / max_val
                    elif not np.issubdtype(audio_data.dtype, np.floating):
                        gr.Warning(
                            f"Unsupported audio prompt dtype {audio_data.dtype}, attempting conversion."
                        )
                        try:
                            audio_data = audio_data.astype(np.float32)
                        except Exception as conv_e:
                            raise gr.Error(
                                f"Failed to convert audio prompt to float32: {conv_e}"
                            )
                    # Downmix stereo to mono; fall back to the first channel for odd layouts
                    if audio_data.ndim > 1:
                        if audio_data.shape[0] == 2:
                            audio_data = np.mean(audio_data, axis=0)
                        elif audio_data.shape[1] == 2:
                            audio_data = np.mean(audio_data, axis=1)
                        else:
                            gr.Warning(
                                f"Audio prompt has unexpected shape {audio_data.shape}, taking first channel/axis."
                            )
                            audio_data = (
                                audio_data[0]
                                if audio_data.shape[0] < audio_data.shape[1]
                                else audio_data[:, 0]
                            )
                        audio_data = np.ascontiguousarray(audio_data)
                    try:
                        sf.write(
                            temp_audio_prompt_path, audio_data, sr, subtype="FLOAT"
                        )
                        prompt_path_for_generate = temp_audio_prompt_path
                        print(
                            f"Created temporary audio prompt file: {temp_audio_prompt_path} (orig sr: {sr})"
                        )
                    except Exception as write_e:
                        print(f"Error writing temporary audio file: {write_e}")
                        raise gr.Error(f"Failed to save audio prompt: {write_e}")
        start_time = time.time()

        with torch.inference_mode():
            # For voice cloning the transcript of the audio prompt must come before
            # the text to be generated (see the guidelines rendered below the UI).
            combined_text = (
                transcription_input.strip() + "\n" + text_input.strip()
                if transcription_input and not transcription_input.isspace()
                else text_input
            )
            output_audio_np = model.generate(
                combined_text,
                max_tokens=max_new_tokens,
                cfg_scale=cfg_scale,
                temperature=temperature,
                top_p=top_p,
                cfg_filter_top_k=cfg_filter_top_k,
                use_torch_compile=False,
                audio_prompt_path=prompt_path_for_generate,
            )

        end_time = time.time()
        print(f"Generation finished in {end_time - start_time:.2f} seconds.")
        if output_audio_np is not None:
            output_sr = 44100
            original_len = len(output_audio_np)
            # Speed adjustment: linearly interpolate the waveform to a new length
            # while keeping the sample rate, which changes playback speed.
            speed_factor = max(0.1, min(speed_factor, 5.0))
            target_len = int(original_len / speed_factor)
            if target_len != original_len and target_len > 0:
                x_original = np.arange(original_len)
                x_resampled = np.linspace(0, original_len - 1, target_len)
                resampled_audio_np = np.interp(x_resampled, x_original, output_audio_np)
                output_audio = (
                    output_sr,
                    resampled_audio_np.astype(np.float32),
                )
                print(
                    f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed."
                )
            else:
                output_audio = (
                    output_sr,
                    output_audio_np,
                )
                print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).")

            print(
                f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
            )

            # Clip and convert float audio to int16, which Gradio plays back most reliably
            if (
                output_audio[1].dtype == np.float32
                or output_audio[1].dtype == np.float64
            ):
                audio_for_gradio = np.clip(output_audio[1], -1.0, 1.0)
                audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
                output_audio = (output_sr, audio_for_gradio)
                print("Converted audio to int16 for Gradio output.")
        else:
            print("\nGeneration finished, but no valid tokens were produced.")
            gr.Warning("Generation produced no output.")
    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback

        traceback.print_exc()
        raise gr.Error(f"Inference failed: {e}")

    finally:
        if temp_txt_file_path and Path(temp_txt_file_path).exists():
            try:
                Path(temp_txt_file_path).unlink()
                print(f"Deleted temporary text file: {temp_txt_file_path}")
            except OSError as e:
                print(
                    f"Warning: Error deleting temporary text file {temp_txt_file_path}: {e}"
                )
        if temp_audio_prompt_path and Path(temp_audio_prompt_path).exists():
            try:
                Path(temp_audio_prompt_path).unlink()
                print(f"Deleted temporary audio prompt file: {temp_audio_prompt_path}")
            except OSError as e:
                print(
                    f"Warning: Error deleting temporary audio prompt file {temp_audio_prompt_path}: {e}"
                )

    return output_audio
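# Illustrative only (not part of the original Space code): run_inference can be
# exercised directly, outside the Gradio UI, with the same defaults the sliders
# below use and no audio prompt. It is left commented out so importing this
# module stays side-effect free; the sample text is a hypothetical prompt.
#
# sample_sr, sample_pcm = run_inference(
#     "[S1] Hello there. [S2] Hi, how are you?",  # hypothetical input text
#     None,   # no audio prompt
#     "",     # no transcription
#     1536, 3.0, 1.8, 0.95, 45, 1.0,
#     list(DIA_MODELS.keys())[0],
# )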
def get_dia_1_6B_tab():
    # NOTE: gr.Tab takes no css argument; this stylesheet is kept here for reference
    # and would need to be passed to the host's gr.Blocks(css=...) to take effect.
    css = """
    #col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
    .dhivehi-text-nofont textarea {
        font-size: 18px !important;
        line-height: 1.8 !important;
        direction: rtl !important;
        text-align: right !important;
    }
    .dhivehi-text-nofont input {
        font-size: 18px !important;
        direction: rtl !important;
        text-align: right !important;
    }
    """

    # Pre-fill the input box from example.txt when the file is present
    default_text = ""
    example_txt_path = Path("./example.txt")
    if example_txt_path.exists():
        try:
            default_text = example_txt_path.read_text(encoding="utf-8").strip()
            if not default_text:
                default_text = "Example text file was empty."
        except Exception as e:
            print(f"Warning: Could not read example.txt: {e}")
    with gr.Tab("🎙️ Dia-1.6B"):
        gr.Markdown("# Dia Text-to-Speech Synthesis (Dia-1.6B)")
        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                model_dropdown = gr.Dropdown(
                    choices=list(DIA_MODELS.keys()),
                    value=list(DIA_MODELS.keys())[0],
                    label="Select Dia Model",
                )
                text_input = gr.Textbox(
                    label="Input Text",
                    placeholder="ލިޔެލަން",
                    value=default_text,
                    lines=5,
                    elem_classes=["dhivehi-text-nofont"],
                )
                audio_prompt_input = gr.Audio(
                    label="Audio Prompt (≤ 10 s, Optional)",
                    show_label=True,
                    sources=["upload", "microphone"],
                    type="numpy",
                )
                transcription_input = gr.Textbox(
                    label="Audio Prompt Transcription (Optional)",
                    placeholder="ޓްރާންސްކްރިޕްޓް ލިޔެލަން",
                    lines=3,
                    elem_classes=["dhivehi-text-nofont"],
                )
                with gr.Accordion("Generation Parameters", open=False):
                    # Load the default model up front so the token-length slider can
                    # take its default from the model config.
                    default_model = load_dia_model(DIA_MODELS[list(DIA_MODELS.keys())[0]])
                    max_new_tokens = gr.Slider(
                        label="Max New Tokens (Audio Length)",
                        minimum=860,
                        maximum=3072,
                        value=getattr(getattr(default_model.config, "data", None), "audio_length", 1536),
                        step=50,
                        info="Controls the maximum length of the generated audio (more tokens = longer audio).",
                    )
                    cfg_scale = gr.Slider(
                        label="CFG Scale (Guidance Strength)",
                        minimum=1.0,
                        maximum=5.0,
                        value=3.0,
                        step=0.1,
                        info="Higher values increase adherence to the text prompt.",
                    )
                    temperature = gr.Slider(
                        label="Temperature (Randomness)",
                        minimum=1.0,
                        maximum=2.5,
                        value=1.8,
                        step=0.05,
                        info="Lower values make the output more deterministic, higher values increase randomness.",
                    )
                    top_p = gr.Slider(
                        label="Top P (Nucleus Sampling)",
                        minimum=0.70,
                        maximum=1.0,
                        value=0.95,
                        step=0.01,
                        info="Restricts sampling to the most likely tokens whose cumulative probability reaches P.",
                    )
                    cfg_filter_top_k = gr.Slider(
                        label="CFG Filter Top K",
                        minimum=15,
                        maximum=100,
                        value=45,
                        step=1,
                        info="Top-k filter applied during CFG guidance.",
                    )
                    speed_factor_slider = gr.Slider(
                        label="Speed Factor",
                        minimum=0.8,
                        maximum=1.0,
                        value=1.0,
                        step=0.02,
                        info="Adjusts the speed of the generated audio (1.0 = original speed).",
                    )

                generate_btn = gr.Button("Generate Audio", variant="primary")
            with gr.Column(scale=1):
                audio_output = gr.Audio(
                    label="Generated Audio",
                    type="numpy",
                    autoplay=False,
                )

        generate_btn.click(
            run_inference,
            inputs=[
                text_input,
                audio_prompt_input,
                transcription_input,
                max_new_tokens,
                cfg_scale,
                temperature,
                top_p,
                cfg_filter_top_k,
                speed_factor_slider,
                model_dropdown,
            ],
            outputs=[audio_output],
        )
        # Examples (optional, can be extended)
        examples_list = [
            [
                """[S1] އައްސަލާމު އަލައިކުމް. (clears throat) Good morning!
[S2] How are you today?
[S1] I'm fine, thanks. ކިހިނެއް ހާލު؟
[S2] (coughs) ކުޑަކޮށް ބަލިކޮށް މިއުޅެނީ
[S1] Oh okay. Get well soon...
[S2] Thanks! See you later... ފަހުން ދިމާވެލާނީ
[S1]""",
                None,
                "",
                1536,
                5.0,
                2.5,
                0.95,
                45,
                1.0,
                list(DIA_MODELS.keys())[0],
            ],
            [
                """[FEMALE-01] [S1] ގައުމަށް އައި މިނިވަން ނޫރާނީ... [S2] ދައުރުން މި ހަނދާންތައް އާކުރަނީ... [S1] އައުދާނަ އިތުރު އަބުޠާލުންނަށް... [S2] ޒިކުރާގެ މަލުން މި ވެދުން ކުރަނީ.""",
                None,
                "",
                1536,
                3.0,
                1.8,
                0.95,
                45,
                0.96,
                list(DIA_MODELS.keys())[0],
            ],
            [
                """[MALE-01] [S1] މާޒީގެ އުޖާލާ މަންޒަރުތައް... [S2] މާރީތި އުފާވެރި ކުރެހުންތައް... [S2] ދާތީ އަދު ހާމަ ވަމުން ކުލަތައް... [S2] ތާރީޚު އަލުން މި އިޢާދަ ވަނީ.""",
                None,
                "",
                1536,
                3.0,
                1.8,
                0.95,
                45,
                0.96,
                list(DIA_MODELS.keys())[0],
            ],
        ]
        if examples_list:
            gr.Examples(
                examples=examples_list,
                inputs=[
                    text_input,
                    audio_prompt_input,
                    transcription_input,
                    max_new_tokens,
                    cfg_scale,
                    temperature,
                    top_p,
                    cfg_filter_top_k,
                    speed_factor_slider,
                    model_dropdown,
                ],
                outputs=[audio_output],
                fn=run_inference,
                cache_examples=False,
                label="Examples (Click to Run)",
            )
        else:
            gr.Markdown("_(No examples configured or example prompt file missing)_")
        gr.Markdown(
            "---\n"
            "**General Guidelines:**\n"
            "- Keep the input text length moderate:\n"
            "  - Short input (corresponding to under 5 s of audio) will sound unnatural.\n"
            "  - Very long input (corresponding to over 20 s of audio) will make the speech unnaturally fast.\n\n"
            "- Use non-verbal tags sparingly, and only those listed in the README. Overusing them, or using unlisted non-verbals, may cause artifacts.\n\n"
            "- Always begin the input text with [S1], and always alternate between [S1] and [S2] (i.e. [S1]... [S1]... is not good).\n\n"
            "**When using audio prompts (voice cloning):**\n"
            "- Provide the transcript of the to-be-cloned audio before the generation text.\n"
            "- The transcript must use the [S1], [S2] speaker tags correctly:\n"
            "  - Single speaker: [S1]...\n"
            "  - Two speakers: [S1]... [S2]...\n"
            "- The to-be-cloned audio should be 5~10 seconds long for the best results.\n"
            "  - (Keep in mind: 1 second ≈ 86 tokens.)\n"
            "- Put [S1] or [S2] (the second-to-last speaker's tag) at the end of the input text to improve audio quality at the end of the generation."
        )
    # No explicit return is needed: the components register themselves with the
    # enclosing gr.Blocks via the context-manager pattern.
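# Minimal host sketch (an assumption, not part of the original Space code): this
# module only defines the tab, so a separate app is expected to build the
# gr.Blocks and mount it. Running the file directly launches such a host; note
# that the RTL stylesheet kept in get_dia_1_6B_tab's local `css` variable is not
# applied here.
if __name__ == "__main__":
    with gr.Blocks(title="Dhivehi Dia TTS") as demo:
        get_dia_1_6B_tab()
    demo.launch()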