Create app.py
app.py
ADDED
@@ -0,0 +1,75 @@
import torchaudio
import matplotlib.pyplot as plt
import gradio as gr
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torchaudio.transforms as T

# Define available model options for dropdown
model_options = {
    "Small Model": "facebook/musicgen-small",
    "Medium Model": "facebook/musicgen-medium",
    "Large Model": "facebook/musicgen-large"
}

# Define style tags options
style_tags_options = ["East Coast", "Trap", "Boom Bap", "Lo-Fi", "Experimental",
                      "Rock", "Electronic", "Pop", "Country", "Heavy Metal",
                      "Classical", "Jazz", "Reggae"]

def generate_spectrogram(audio_tensor, sample_rate):
    # Compute a power spectrogram of the generated waveform and render it to a PNG.
    spectrogram_transform = T.Spectrogram(n_fft=400, win_length=400, hop_length=160)
    spectrogram = spectrogram_transform(audio_tensor.cpu())
    spectrogram_db = T.AmplitudeToDB(stype='power')(spectrogram)

    plt.figure(figsize=(10, 4))
    num_frames = spectrogram_db.shape[-1]
    plt.imshow(spectrogram_db.numpy(), origin='lower', aspect='auto', cmap='viridis',
               extent=[0, num_frames * 160 / sample_rate, 0, sample_rate / 2])
    plt.colorbar(format='%+2.0f dB')
    plt.title('Spectrogram')
    plt.ylabel('Frequency (Hz)')
    plt.xlabel('Time (s)')
    plt.tight_layout()
    spectrogram_path = "generated_spectrogram.png"
    plt.savefig(spectrogram_path)
    plt.close()
    return spectrogram_path

def generate_music(description, model_choice, style_tags, tempo, intensity, duration):
    try:
        processor = AutoProcessor.from_pretrained(model_options[model_choice])
        model = MusicgenForConditionalGeneration.from_pretrained(model_options[model_choice])

        # Fold the selected style tags and slider settings into the text prompt,
        # since MusicGen is conditioned on text only.
        style_tags_str = " ".join(style_tags)
        prompt = f"{description} {style_tags_str}, {tempo} bpm, intensity {intensity}/10"

        inputs = processor(text=[prompt], return_tensors="pt", padding=True)
        # MusicGen produces roughly 50 audio tokens per second; cap the request at
        # 30 seconds, the longest clip the model reliably generates in one pass.
        max_new_tokens = int(min(duration, 30) * 50)
        audio_output = model.generate(**inputs, max_new_tokens=max_new_tokens)

        # The audio encoder defines the output sampling rate (32 kHz for these checkpoints).
        sampling_rate = model.config.audio_encoder.sampling_rate
        output_file = "generated_music.wav"
        torchaudio.save(output_file, audio_output[0].cpu(), sampling_rate)
        spectrogram_path = generate_spectrogram(audio_output[0].squeeze(), sampling_rate)

        return output_file, spectrogram_path, None
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        return None, None, error_message

iface = gr.Interface(
    fn=generate_music,
    inputs=[
        gr.Textbox(label="Enter a description for the music"),
        gr.Dropdown(label="Select Model", choices=list(model_options.keys())),
        gr.CheckboxGroup(label="Style Tags", choices=style_tags_options),
        gr.Slider(label="Tempo", minimum=60, maximum=240, step=1, value=120),
        gr.Slider(label="Intensity", minimum=1, maximum=10, step=1, value=5),
        gr.Slider(label="Duration (Seconds)", minimum=15, maximum=300, step=1, value=60)
    ],
    outputs=[
        gr.Audio(label="Generated Music"),
        gr.Image(label="Spectrogram"),
        gr.Textbox(label="Error Message", visible=True)
    ],
    title="MusicGen Pro XL",
    description="Generate original music from multiple genres with customizable parameters and style tags. Listen to the generated music, visualize the spectrogram, and receive error messages if any."
)

iface.launch()
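
For a quick local check outside the Gradio UI, generate_music can also be called directly. The snippet below is a hypothetical smoke test, not part of the committed file; it assumes the facebook/musicgen-small weights can be downloaded, and the prompt text and slider values are made up for illustration.

# Hypothetical smoke test: call generate_music directly with the Small Model.
audio_path, spectrogram_path, error = generate_music(
    description="mellow piano loop with soft drums",
    model_choice="Small Model",
    style_tags=["Lo-Fi", "Jazz"],
    tempo=85,
    intensity=3,
    duration=15,
)
print(audio_path, spectrogram_path, error)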