Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -93,21 +93,21 @@ def GenerateMIDI(num_tok, idrums, iinstr):
|
|
| 93 |
model = TransformerWrapper(
|
| 94 |
num_tokens=3088,
|
| 95 |
max_seq_len=SEQ_LEN,
|
| 96 |
-
attn_layers=Decoder(dim=1024, depth=16, heads=8)
|
| 97 |
)
|
| 98 |
|
| 99 |
model = AutoregressiveWrapper(model)
|
| 100 |
|
| 101 |
model = torch.nn.DataParallel(model)
|
| 102 |
|
| 103 |
-
model.
|
| 104 |
print('=' * 70)
|
| 105 |
|
| 106 |
print('Loading model checkpoint...')
|
| 107 |
|
| 108 |
model.load_state_dict(
|
| 109 |
torch.load('Allegro_Music_Transformer_Tiny_Trained_Model_80000_steps_0.9457_loss_0.7443_acc.pth',
|
| 110 |
-
map_location='
|
| 111 |
print('=' * 70)
|
| 112 |
|
| 113 |
model.eval()
|
|
@@ -125,13 +125,14 @@ def GenerateMIDI(num_tok, idrums, iinstr):
|
|
| 125 |
|
| 126 |
for i in range(max(1, min(512, num_tok))):
|
| 127 |
|
| 128 |
-
inp = torch.LongTensor([outy]).
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
|
|
|
| 135 |
|
| 136 |
out0 = out[0].tolist()
|
| 137 |
outy.extend(out0)
|
|
@@ -253,7 +254,7 @@ if __name__ == "__main__":
|
|
| 253 |
run_btn = gr.Button("generate", variant="primary")
|
| 254 |
interrupt_btn = gr.Button("interrupt")
|
| 255 |
|
| 256 |
-
output_midi_seq = gr.
|
| 257 |
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
|
| 258 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
| 259 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
|
|
|
| 93 |
model = TransformerWrapper(
|
| 94 |
num_tokens=3088,
|
| 95 |
max_seq_len=SEQ_LEN,
|
| 96 |
+
attn_layers=Decoder(dim=1024, depth=16, heads=8, attn_flash=True)
|
| 97 |
)
|
| 98 |
|
| 99 |
model = AutoregressiveWrapper(model)
|
| 100 |
|
| 101 |
model = torch.nn.DataParallel(model)
|
| 102 |
|
| 103 |
+
model.cuda()
|
| 104 |
print('=' * 70)
|
| 105 |
|
| 106 |
print('Loading model checkpoint...')
|
| 107 |
|
| 108 |
model.load_state_dict(
|
| 109 |
torch.load('Allegro_Music_Transformer_Tiny_Trained_Model_80000_steps_0.9457_loss_0.7443_acc.pth',
|
| 110 |
+
map_location='cuda'))
|
| 111 |
print('=' * 70)
|
| 112 |
|
| 113 |
model.eval()
|
|
|
|
| 125 |
|
| 126 |
for i in range(max(1, min(512, num_tok))):
|
| 127 |
|
| 128 |
+
inp = torch.LongTensor([outy]).cuda()
|
| 129 |
|
| 130 |
+
with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
|
| 131 |
+
out = model.module.generate(inp,
|
| 132 |
+
1,
|
| 133 |
+
temperature=0.9,
|
| 134 |
+
return_prime=False,
|
| 135 |
+
verbose=False)
|
| 136 |
|
| 137 |
out0 = out[0].tolist()
|
| 138 |
outy.extend(out0)
|
|
|
|
| 254 |
run_btn = gr.Button("generate", variant="primary")
|
| 255 |
interrupt_btn = gr.Button("interrupt")
|
| 256 |
|
| 257 |
+
output_midi_seq = gr.HTML()
|
| 258 |
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
|
| 259 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
| 260 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|