File size: 3,104 Bytes
dc618e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torchaudio
import matplotlib.pyplot as plt
import gradio as gr
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torchaudio.transforms as T

# Dropdown label -> Hugging Face checkpoint id, one entry per MusicGen size.
model_options = {
    f"{size.capitalize()} Model": f"facebook/musicgen-{size}"
    for size in ("small", "medium", "large")
}

# Genre/style labels offered in the "Style Tags" checkbox group.
style_tags_options = [
    "East Coast",
    "Trap",
    "Boom Bap",
    "Lo-Fi",
    "Experimental",
    "Rock",
    "Electronic",
    "Pop",
    "Country",
    "Heavy Metal",
    "Classical",
    "Jazz",
    "Reggae",
]

def generate_spectrogram(audio_tensor, sample_rate):
    """Render a spectrogram of a waveform and save it as a PNG image.

    Args:
        audio_tensor: Waveform samples as a torch tensor. Either 1-D
            ``(samples,)`` or with leading dimensions such as
            ``(channels, samples)``; only the first row is plotted.
        sample_rate: Sampling rate of the waveform in Hz, used to scale
            the time and frequency axes.

    Returns:
        The path of the written image, ``"generated_spectrogram.png"``.
    """
    # The input is already a time-domain waveform, so it is plotted as-is.
    # GriffinLim goes the other way (magnitude spectrogram -> waveform) and
    # must not be applied to raw samples.
    samples = audio_tensor.detach().cpu()
    if samples.dim() > 1:
        # Collapse any leading dims and keep the first channel only.
        samples = samples.reshape(-1, samples.shape[-1])[0]
    samples = samples.numpy()

    # Work on an explicit figure so concurrent requests do not share
    # pyplot's implicit "current figure".
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        _, _, _, image = ax.specgram(samples, Fs=sample_rate, cmap='viridis')
        fig.colorbar(image, ax=ax, format='%+2.0f dB')
        ax.set_title('Spectrogram')
        ax.set_ylabel('Frequency (Hz)')
        ax.set_xlabel('Time (s)')
        # Lay out after the labels exist so they are not clipped.
        fig.tight_layout()
        spectrogram_path = "generated_spectrogram.png"
        fig.savefig(spectrogram_path)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    return spectrogram_path

def generate_music(description, model_choice, style_tags, tempo, intensity, duration):
    """Generate a music clip from a text prompt and save it with a spectrogram.

    Args:
        description: Free-text description of the desired music.
        model_choice: Key into ``model_options`` selecting the checkpoint.
        style_tags: Selected style tag strings (may be empty or ``None``);
            they are appended to the description to form the prompt.
        tempo: Accepted from the UI but currently not used for generation.
        intensity: Accepted from the UI but currently not used for generation.
        duration: Accepted from the UI but currently not used for generation;
            the clip length is fixed by ``max_new_tokens=256``.

    Returns:
        A tuple ``(audio_path, spectrogram_path, error_message)``. On success
        the first two are file paths and the third is ``None``; on failure the
        first two are ``None`` and the third describes the error.
    """
    # This is the UI boundary: any failure is reported to the user as text
    # instead of crashing the request.
    try:
        if model_choice not in model_options:
            return None, None, "An error occurred: please select a model from the dropdown."

        checkpoint = model_options[model_choice]
        processor = AutoProcessor.from_pretrained(checkpoint)
        model = MusicgenForConditionalGeneration.from_pretrained(checkpoint)

        # Convert the list of selected style tags into a single string;
        # tolerate a missing selection.
        style_tags_str = " ".join(style_tags or [])

        inputs = processor(text=[description + " " + style_tags_str], return_tensors="pt", padding=True)
        audio_output = model.generate(**inputs, max_new_tokens=256)

        # Use the rate the model's audio codec actually produces; a hard-coded
        # value plays back at the wrong speed/pitch when it does not match.
        sampling_rate = model.config.audio_encoder.sampling_rate
        output_file = "generated_music.wav"
        torchaudio.save(output_file, audio_output[0].cpu(), sampling_rate)
        spectrogram_path = generate_spectrogram(audio_output[0].squeeze(), sampling_rate)

        return output_file, spectrogram_path, None
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        return None, None, error_message

# Gradio UI: the inputs map positionally onto generate_music's parameters and
# the outputs onto its (audio path, spectrogram path, error message) tuple.
iface = gr.Interface(
    fn=generate_music,
    inputs=[
        gr.Textbox(label="Enter a description for the music"),
        # Pre-select the first model so a submission never arrives without one.
        gr.Dropdown(
            label="Select Model",
            choices=list(model_options.keys()),
            value=next(iter(model_options)),
        ),
        gr.CheckboxGroup(label="Style Tags", choices=style_tags_options),
        gr.Slider(label="Tempo", minimum=60, maximum=240, step=1, value=120),
        gr.Slider(label="Intensity", minimum=1, maximum=10, step=1, value=5),
        gr.Slider(label="Duration (Seconds)", minimum=15, maximum=300, step=1, value=60)
    ],
    outputs=[
        gr.Audio(label="Generated Music"),
        gr.Image(label="Spectrogram"),
        gr.Textbox(label="Error Message", visible=True)
    ],
    title="MusicGen Pro XL",
    description="Generate original music from multiple genres with customizable parameters and style tags. Listen to the generated music, visualize the spectrogram, and receive error messages if any."
)

# Start the server only when run as a script, so importing this module
# (e.g. for tests or reuse of `iface`) has no side effects.
if __name__ == "__main__":
    iface.launch()