Kabatubare commited on
Commit
dc618e8
·
verified ·
1 Parent(s): 579aaa6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchaudio
2
+ import matplotlib.pyplot as plt
3
+ import gradio as gr
4
+ from transformers import AutoProcessor, MusicgenForConditionalGeneration
5
+ import torchaudio.transforms as T
6
+
7
# Human-readable dropdown labels mapped to the Hugging Face MusicGen
# checkpoints they select.
model_options = {
    "Small Model": "facebook/musicgen-small",
    "Medium Model": "facebook/musicgen-medium",
    "Large Model": "facebook/musicgen-large",
}

# Genre/style tags offered as checkboxes; the selected ones are joined
# into the text prompt by generate_music.
style_tags_options = [
    "East Coast", "Trap", "Boom Bap", "Lo-Fi", "Experimental",
    "Rock", "Electronic", "Pop", "Country", "Heavy Metal",
    "Classical", "Jazz", "Reggae",
]
19
def generate_spectrogram(audio_tensor, sample_rate):
    """Render a spectrogram of a waveform to a PNG and return its path.

    Parameters:
        audio_tensor: torch tensor holding a time-domain waveform, either
            1-D ``(samples,)`` or ``(channels, samples)``; only the first
            channel of a multi-channel tensor is plotted.
        sample_rate: sampling rate in Hz, used to scale the frequency axis.

    Returns:
        str: path of the saved image, ``"generated_spectrogram.png"``.
    """
    # Detach and move to CPU so .numpy() works regardless of device/grad state.
    waveform = audio_tensor.detach().cpu()
    if waveform.dim() > 1:
        waveform = waveform[0]

    plt.figure(figsize=(10, 4))
    # plt.specgram computes the STFT itself from the raw waveform.  The
    # previous Griffin-Lim step was removed: GriffinLim converts a magnitude
    # spectrogram *back* into audio, so applying it to an already-time-domain
    # signal was incorrect (the inverse of the transformation needed here).
    plt.specgram(waveform.numpy(), Fs=sample_rate, cmap='viridis')
    plt.colorbar(format='%+2.0f dB')
    plt.title('Spectrogram')
    plt.ylabel('Frequency (Hz)')
    plt.xlabel('Time (s)')
    # tight_layout is called after the labels are set so they are accounted
    # for when the figure margins are computed.
    plt.tight_layout()
    spectrogram_path = "generated_spectrogram.png"
    plt.savefig(spectrogram_path)
    plt.close()
    return spectrogram_path
34
+
35
def generate_music(description, model_choice, style_tags, tempo, intensity, duration):
    """Generate music with MusicGen and return audio, spectrogram, and error.

    Parameters:
        description: free-text prompt describing the desired music.
        model_choice: key into ``model_options`` selecting the checkpoint.
        style_tags: list of selected style-tag strings.
        tempo: requested tempo in BPM (folded into the text prompt).
        intensity: requested intensity 1-10 (folded into the text prompt).
        duration: requested clip length in seconds (controls token budget).

    Returns:
        tuple: ``(wav_path, spectrogram_path, None)`` on success, or
        ``(None, None, error_message)`` on failure.
    """
    try:
        model_name = model_options[model_choice]
        processor = AutoProcessor.from_pretrained(model_name)
        model = MusicgenForConditionalGeneration.from_pretrained(model_name)

        # Fold every user control into the conditioning text; tempo and
        # intensity were previously accepted but silently ignored.
        style_tags_str = " ".join(style_tags)
        prompt = f"{description} {style_tags_str} tempo {int(tempo)} BPM intensity {int(intensity)}/10"

        inputs = processor(text=[prompt], return_tensors="pt", padding=True)

        # MusicGen emits ~50 audio tokens per second of output, so derive the
        # token budget from the requested duration instead of a fixed 256
        # tokens (~5 s).  Capped at 1503 tokens (~30 s), the model's
        # generation limit, to keep runtime bounded.
        max_new_tokens = min(int(duration) * 50, 1503)
        audio_output = model.generate(**inputs, max_new_tokens=max_new_tokens)

        # Read the true output rate from the model config (32 kHz for the
        # facebook/musicgen checkpoints).  The previous hard-coded 16000
        # saved the audio at half its real rate, playing it back slowed
        # down and an octave low.
        sampling_rate = model.config.audio_encoder.sampling_rate

        output_file = "generated_music.wav"
        torchaudio.save(output_file, audio_output[0].cpu(), sampling_rate)
        spectrogram_path = generate_spectrogram(audio_output[0].squeeze(), sampling_rate)

        return output_file, spectrogram_path, None
    except Exception as e:
        # Surface the failure in the UI's error textbox instead of crashing
        # the Gradio worker; the broad catch is deliberate at this boundary.
        return None, None, f"An error occurred: {str(e)}"
55
+
56
# Gradio UI wiring: declare the input and output components separately so
# the Interface construction below stays readable.
music_inputs = [
    gr.Textbox(label="Enter a description for the music"),
    gr.Dropdown(label="Select Model", choices=list(model_options.keys())),
    gr.CheckboxGroup(label="Style Tags", choices=style_tags_options),
    gr.Slider(label="Tempo", minimum=60, maximum=240, step=1, value=120),
    gr.Slider(label="Intensity", minimum=1, maximum=10, step=1, value=5),
    gr.Slider(label="Duration (Seconds)", minimum=15, maximum=300, step=1, value=60),
]

music_outputs = [
    gr.Audio(label="Generated Music"),
    gr.Image(label="Spectrogram"),
    gr.Textbox(label="Error Message", visible=True),
]

iface = gr.Interface(
    fn=generate_music,
    inputs=music_inputs,
    outputs=music_outputs,
    title="MusicGen Pro XL",
    description="Generate original music from multiple genres with customizable parameters and style tags. Listen to the generated music, visualize the spectrogram, and receive error messages if any.",
)

iface.launch()