"""Gradio app for generating music with Meta's MusicGen models.

The user enters a text description, picks a model size, style tags, tempo,
intensity, and duration; the app returns a WAV file plus a spectrogram image
of the generated audio (or an error message on failure).
"""

import matplotlib.pyplot as plt
import torchaudio
import gradio as gr
from transformers import AutoProcessor, MusicgenForConditionalGeneration

# Hugging Face hub ids for the selectable MusicGen checkpoints.
model_options = {
    "Small Model": "facebook/musicgen-small",
    "Medium Model": "facebook/musicgen-medium",
    "Large Model": "facebook/musicgen-large",
}

# Genre/style tags the user can mix into the prompt.
style_tags_options = [
    "East Coast", "Trap", "Boom Bap", "Lo-Fi", "Experimental", "Rock",
    "Electronic", "Pop", "Country", "Heavy Metal", "Classical", "Jazz",
    "Reggae",
]

# Cache of loaded (processor, model) pairs keyed by model id, so repeated
# requests do not re-download / re-instantiate the checkpoint every time.
_loaded_models = {}

# MusicGen produces roughly 50 audio tokens per second of output
# (see the MusicGen model card / transformers docs).
_TOKENS_PER_SECOND = 50


def _get_model(model_id):
    """Return a cached (processor, model) pair for *model_id*, loading on first use."""
    if model_id not in _loaded_models:
        processor = AutoProcessor.from_pretrained(model_id)
        model = MusicgenForConditionalGeneration.from_pretrained(model_id)
        _loaded_models[model_id] = (processor, model)
    return _loaded_models[model_id]


def generate_spectrogram(audio_tensor, sample_rate):
    """Render a spectrogram PNG of *audio_tensor* and return the file path.

    *audio_tensor* is a waveform, either 1-D or 2-D as (channels, samples);
    the first channel is plotted.

    NOTE: the previous implementation ran ``T.GriffinLim`` on the waveform
    first, but Griffin-Lim converts a *spectrogram* back into audio —
    applying it to a raw waveform is meaningless (and with a 1-D input,
    ``waveform.numpy()[0]`` would have indexed a scalar).  The waveform is
    now plotted directly.
    """
    waveform = audio_tensor.detach().cpu()
    if waveform.dim() > 1:
        waveform = waveform[0]  # plot the first channel only

    plt.figure(figsize=(10, 4))
    plt.specgram(waveform.numpy(), Fs=sample_rate, cmap='viridis')
    plt.colorbar(format='%+2.0f dB')
    plt.title('Spectrogram')
    plt.ylabel('Frequency (Hz)')
    plt.xlabel('Time (s)')
    plt.tight_layout()

    spectrogram_path = "generated_spectrogram.png"
    plt.savefig(spectrogram_path)
    plt.close()
    return spectrogram_path


def generate_music(description, model_choice, style_tags, tempo, intensity, duration):
    """Generate music from the UI controls.

    Returns ``(wav_path, spectrogram_path, None)`` on success, or
    ``(None, None, error_message)`` on failure.
    """
    try:
        processor, model = _get_model(model_options[model_choice])

        # Fold every user control into the text prompt; the previous version
        # silently ignored tempo and intensity.
        prompt_parts = [
            description,
            " ".join(style_tags),
            f"tempo: {tempo} bpm",
            f"intensity: {intensity}/10",
        ]
        prompt = " ".join(part for part in prompt_parts if part)

        inputs = processor(text=[prompt], return_tensors="pt", padding=True)

        # Honour the requested duration instead of a fixed 256 tokens.
        max_new_tokens = max(1, int(duration * _TOKENS_PER_SECOND))
        audio_output = model.generate(**inputs, max_new_tokens=max_new_tokens)

        # Use the model's real output rate (32 kHz for MusicGen) rather than
        # a hard-coded 16 kHz, which made the saved audio play at half speed.
        sampling_rate = model.config.audio_encoder.sampling_rate

        output_file = "generated_music.wav"
        waveform = audio_output[0].cpu()  # (channels, samples) for torchaudio.save
        torchaudio.save(output_file, waveform, sampling_rate)

        spectrogram_path = generate_spectrogram(waveform, sampling_rate)
        return output_file, spectrogram_path, None
    except Exception as e:  # UI boundary: surface the error instead of crashing the app
        return None, None, f"An error occurred: {e}"


iface = gr.Interface(
    fn=generate_music,
    inputs=[
        gr.Textbox(label="Enter a description for the music"),
        gr.Dropdown(label="Select Model", choices=list(model_options.keys())),
        gr.CheckboxGroup(label="Style Tags", choices=style_tags_options),
        gr.Slider(label="Tempo", minimum=60, maximum=240, step=1, value=120),
        gr.Slider(label="Intensity", minimum=1, maximum=10, step=1, value=5),
        gr.Slider(label="Duration (Seconds)", minimum=15, maximum=300, step=1, value=60),
    ],
    outputs=[
        gr.Audio(label="Generated Music"),
        gr.Image(label="Spectrogram"),
        gr.Textbox(label="Error Message", visible=True),
    ],
    title="MusicGen Pro XL",
    description=(
        "Generate original music from multiple genres with customizable "
        "parameters and style tags. Listen to the generated music, visualize "
        "the spectrogram, and receive error messages if any."
    ),
)

if __name__ == "__main__":
    iface.launch()