network for training """
import glob
import pickle
-import numpy
+import numpy as np
from music21 import converter, instrument, note, chord, stream
import os
from keras.models import Sequential
...
from keras.layers import Activation
from keras.utils import np_utils
from keras.callbacks import ModelCheckpoint, Callback
+import subprocess
import wandb
import base64
wandb.init()

+def ensure_midi(dataset="mario"):
+    if not os.path.exists("data/%s" % dataset):
+        print("Downloading %s dataset..." % dataset)
+        subprocess.check_output(
+            "curl -SL https://storage.googleapis.com/wandb/%s.tar.gz | tar xz" % dataset, shell=True)  # finalfantasy
+        open("data/%s" % dataset, "w").close()
+

def train_network():
    """ Train a Neural Network to generate music """
@@ -38,7 +46,11 @@ def get_notes():
        return pickle.load(open("data/notes", "rb"))

    for file in glob.glob("midi_songs/*.mid"):
-        midi = converter.parse(file)
+        try:
+            midi = converter.parse(file)
+        except TypeError:
+            print("Invalid file %s" % file)
+            continue

        print("Parsing %s" % file)

@@ -86,7 +98,7 @@ def prepare_sequences(notes, n_vocab):
    n_patterns = len(network_input)

    # reshape the input into a format compatible with LSTM layers
-    network_input = numpy.reshape(
+    network_input = np.reshape(
        network_input, (n_patterns, sequence_length, 1))
    # normalize input
    network_input = network_input / float(n_vocab)
@@ -118,27 +130,39 @@ def create_network(network_input, n_vocab):


class Midi(Callback):
+    """
+    Callback for sampling a midi file
+    """
+    def sample(self, preds, temperature=1.0):
+        # helper function to sample an index from a probability array
+        preds = np.asarray(preds).astype('float64')
+        preds = np.log(preds) / temperature
+        exp_preds = np.exp(preds)
+        preds = exp_preds / np.sum(exp_preds)
+        probas = np.random.multinomial(1, preds, 1)
+        return np.argmax(probas)
+
    def generate_notes(self, network_input, pitchnames, n_vocab):
        """ Generate notes from the neural network based on a sequence of notes """
        # pick a random sequence from the input as a starting point for the prediction
        model = self.model
-        start = numpy.random.randint(0, len(network_input)-1)
+        start = np.random.randint(0, len(network_input)-1)

        int_to_note = dict((number, note)
                           for number, note in enumerate(pitchnames))

-        pattern = network_input[start]
+        pattern = list(network_input[start])
        prediction_output = []

        # generate 500 notes
        for note_index in range(500):
-            prediction_input = numpy.reshape(pattern, (1, len(pattern), 1))
+            prediction_input = np.reshape(pattern, (1, len(pattern), 1))
            prediction_input = prediction_input / float(n_vocab)

            prediction = model.predict(prediction_input, verbose=0)

            # TODO: add random picking
-            index = numpy.argmax(prediction)
+            index = np.argmax(prediction)  # self.sample(prediction)  # np.argmax
            result = int_to_note[index]
            prediction_output.append(result)

@@ -190,7 +214,7 @@ def on_epoch_end(self, *args):
            notes, n_vocab)
        music = self.generate_notes(network_input, pitchnames, n_vocab)
        midi = self.create_midi(music)
-        midi.seek(0)
+        midi = open(midi, "rb")
        data = "data:audio/midi;base64,%s" % base64.b64encode(
            midi.read()).decode("utf8")
        wandb.log({
@@ -219,4 +243,5 @@ def train(model, network_input, network_output):


if __name__ == '__main__':
+    ensure_midi("finalfantasy")
    train_network()
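Note: the commit adds a `sample()` helper to the `Midi` callback but leaves the greedy `np.argmax` pick in place (the call appears only in a trailing comment). A minimal, hypothetical sketch of how temperature sampling could be wired in follows; it assumes the softmax output of `model.predict()` has shape `(1, n_vocab)`, so the first row is what gets passed to the sampler.

```python
# Hypothetical sketch, not part of the commit: temperature sampling in place of
# the greedy argmax pick inside generate_notes().
import numpy as np

def sample(preds, temperature=1.0):
    # re-weight the probability array by temperature, then draw one index
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)

# stand-in for: prediction = model.predict(prediction_input, verbose=0)
prediction = np.array([[0.1, 0.6, 0.3]])
index = sample(prediction[0], temperature=0.8)  # instead of np.argmax(prediction)
print(index)
```

Lower temperatures keep the generation close to the argmax choice; higher temperatures trade coherence for variety.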
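The change from `midi.seek(0)` to `midi = open(midi, "rb")` suggests `create_midi()` now returns a path to a MIDI file on disk rather than an in-memory stream. A small sketch of that encoding step, assuming a hypothetical `output.mid` path:

```python
# Sketch under the assumption that create_midi() returns a path to a .mid file.
import base64

def midi_to_data_uri(path):
    # read the MIDI bytes and wrap them in a base64 data URI for wandb logging
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf8")
    return "data:audio/midi;base64,%s" % encoded

# usage: data = midi_to_data_uri("output.mid")
```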