|
| 1 | +""" This module prepares midi file data and feeds it to the neural |
| 2 | + network for training """ |
| 3 | +import glob |
| 4 | +import pickle |
| 5 | +import numpy |
| 6 | +from music21 import converter, instrument, note, chord, stream |
| 7 | +import os |
| 8 | +from keras.models import Sequential |
| 9 | +from keras.layers import Dense |
| 10 | +from keras.layers import Dropout |
| 11 | +from keras.layers import LSTM, CuDNNGRU |
| 12 | +from keras.layers import Activation |
| 13 | +from keras.utils import np_utils |
| 14 | +from keras.callbacks import ModelCheckpoint, Callback |
| 15 | +import wandb |
| 16 | +import base64 |
# Start the Weights & Biases run at import time; the wandb.log call and the
# WandbCallback used further down in this module are issued after this.
wandb.init()
| 18 | + |
| 19 | + |
def train_network():
    """Run the end-to-end training pipeline for the music generator.

    Loads the note corpus, encodes it into model inputs and targets, builds
    a network sized to the vocabulary, and starts training.
    """
    corpus = get_notes()

    # vocabulary size = number of distinct note/chord tokens in the corpus
    vocab_size = len(set(corpus))

    inputs, targets = prepare_sequences(corpus, vocab_size)

    network = create_network(inputs, vocab_size)
    train(network, inputs, targets)
| 32 | + |
| 33 | + |
def get_notes():
    """Return every note and chord token found in the MIDI corpus.

    The parsed result is cached as a pickle at ``data/notes``. When that file
    exists it is loaded and returned without reading any MIDI file.
    Otherwise every ``midi_songs/*.mid`` file is parsed: single notes are
    recorded as their pitch string, chords as their normal-order integers
    joined with ``.``. The resulting list is then written to the cache
    (creating the ``data`` directory if needed) and returned.

    Returns:
        list of str: note/chord tokens in the order they were encountered.
    """
    cache_path = "data/notes"

    if os.path.exists(cache_path):
        # NOTE: unpickling can execute arbitrary code; this cache must only
        # ever be a file produced by a previous run of this function.
        with open(cache_path, "rb") as cache_file:
            return pickle.load(cache_file)

    notes = []
    for midi_path in glob.glob("midi_songs/*.mid"):
        midi = converter.parse(midi_path)

        print("Parsing %s" % midi_path)

        try:  # file has instrument parts
            by_instrument = instrument.partitionByInstrument(midi)
            notes_to_parse = by_instrument.parts[0].recurse()
        except Exception:  # file has notes in a flat structure
            # Exception (not a bare except) so Ctrl-C / SystemExit still
            # propagate while any parsing failure falls back to flat notes.
            notes_to_parse = midi.flat.notes

        for element in notes_to_parse:
            if isinstance(element, note.Note):
                notes.append(str(element.pitch))
            elif isinstance(element, chord.Chord):
                notes.append('.'.join(str(n) for n in element.normalOrder))

    # The cache directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, 'wb') as cache_file:
        pickle.dump(notes, cache_file)

    return notes
| 63 | + |
| 64 | + |
def prepare_sequences(notes, n_vocab):
    """Prepare the sequences used by the Neural Network.

    Every distinct token in ``notes`` is mapped to its index in the sorted
    vocabulary. Each window of 100 consecutive tokens becomes one input
    sequence and the token that follows it becomes the target.

    Args:
        notes: sequence of note/chord tokens (hashable, sortable).
        n_vocab: vocabulary size used to scale the inputs and to size the
            one-hot targets; callers in this file pass ``len(set(notes))``.

    Returns:
        tuple: ``(network_input, network_output)`` where ``network_input`` is
        a numpy array of shape ``(n_patterns, 100, 1)`` holding token indices
        divided by ``n_vocab``, and ``network_output`` is the one-hot
        encoding of the targets with ``n_vocab`` columns.
    """
    sequence_length = 100

    # get all pitch names
    pitchnames = sorted(set(notes))

    # create a dictionary to map pitches to integers
    note_to_int = {name: number for number, name in enumerate(pitchnames)}

    # encode the corpus once instead of re-encoding every overlapping window
    encoded = [note_to_int[item] for item in notes]

    network_input = []
    network_output = []

    # create input sequences and the corresponding outputs
    for i in range(len(encoded) - sequence_length):
        network_input.append(encoded[i:i + sequence_length])
        network_output.append(encoded[i + sequence_length])

    n_patterns = len(network_input)

    # reshape the input into a format compatible with recurrent layers
    network_input = numpy.reshape(
        network_input, (n_patterns, sequence_length, 1))
    # normalize input
    network_input = network_input / float(n_vocab)

    # Fix the one-hot width explicitly: without num_classes it is inferred
    # from the largest target present, which is smaller than n_vocab when the
    # highest-indexed token never occurs as a target.
    network_output = np_utils.to_categorical(
        network_output, num_classes=n_vocab)

    return (network_input, network_output)
| 97 | + |
| 98 | + |
def create_network(network_input, n_vocab):
    """Build and compile the recurrent model.

    The first layer's input shape is taken from axes 1 and 2 of
    ``network_input``; the final dense layer has ``n_vocab`` units followed
    by a softmax activation.
    """
    timesteps = network_input.shape[1]
    features = network_input.shape[2]

    model = Sequential()
    stack = (
        CuDNNGRU(256, input_shape=(timesteps, features),
                 return_sequences=True),
        Dropout(0.3),
        CuDNNGRU(128, return_sequences=True),
        Dropout(0.3),
        CuDNNGRU(64),
        Dense(256),
        Dropout(0.3),
        Dense(n_vocab),
        Activation('softmax'),
    )
    for layer in stack:
        model.add(layer)

    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
    return model
| 118 | + |
| 119 | + |
class Midi(Callback):
    """Keras callback that generates a MIDI sample with the current model at
    the end of every epoch and logs it to wandb as an HTML player."""

    def generate_notes(self, network_input, pitchnames, n_vocab):
        """ Generate notes from the neural network based on a sequence of notes

        Args:
            network_input: non-empty collection of seed sequences holding
                integer token indices (NOT yet divided by ``n_vocab``); each
                sequence may be a list or a numpy array of any shape.
            pitchnames: sorted vocabulary; position ``i`` is the token for
                index ``i``.
            n_vocab: vocabulary size used to scale the model input.

        Returns:
            list: 500 generated note/chord tokens.

        Raises:
            ValueError: if ``network_input`` is empty.
        """
        if len(network_input) == 0:
            raise ValueError(
                "network_input must contain at least one seed sequence")

        model = self.model
        # pick a random sequence from the input as a starting point for the
        # prediction; randint's upper bound is exclusive, so pass the full
        # length to make every seed (including a lone one) selectable
        start = numpy.random.randint(0, len(network_input))

        int_to_note = dict(enumerate(pitchnames))

        # Copy the seed into a plain list of ints: the sliding window below
        # needs append/slice, and the caller's data must not be mutated.
        pattern = [int(value) for value in numpy.ravel(network_input[start])]
        prediction_output = []

        # generate 500 notes
        for _ in range(500):
            prediction_input = numpy.reshape(pattern, (1, len(pattern), 1))
            prediction_input = prediction_input / float(n_vocab)

            prediction = model.predict(prediction_input, verbose=0)

            # TODO: add random picking
            index = int(numpy.argmax(prediction))
            prediction_output.append(int_to_note[index])

            # slide the window: drop the oldest index, append the new one
            pattern.append(index)
            pattern = pattern[1:]

        return prediction_output

    def create_midi(self, prediction_output):
        """ convert the output from the prediction to notes and create a midi file
            from the notes

        Tokens containing ``.`` or consisting only of digits are treated as
        chords (``.``-separated integers); anything else is passed to
        ``note.Note`` as a single note.

        Returns:
            whatever ``Stream.write('midi')`` returns (presumably the path of
            the written file -- see music21 docs).
        """
        offset = 0
        output_notes = []

        # create note and chord objects based on the values generated by the model
        for pattern in prediction_output:
            # pattern is a chord
            if ('.' in pattern) or pattern.isdigit():
                chord_notes = []
                for current_note in pattern.split('.'):
                    new_note = note.Note(int(current_note))
                    new_note.storedInstrument = instrument.Piano()
                    chord_notes.append(new_note)
                new_chord = chord.Chord(chord_notes)
                new_chord.offset = offset
                output_notes.append(new_chord)
            # pattern is a note
            else:
                new_note = note.Note(pattern)
                new_note.offset = offset
                new_note.storedInstrument = instrument.Piano()
                output_notes.append(new_note)

            # increase offset each iteration so that notes do not stack
            offset += 0.5

        midi_stream = stream.Stream(output_notes)

        return midi_stream.write('midi')

    def on_epoch_end(self, *args):
        """Generate a sample with the current weights and log it to wandb.

        Positional arguments supplied by the caller are accepted and ignored.
        """
        notes = get_notes()
        # Get all pitch names
        pitchnames = sorted(set(notes))
        # Vocabulary size
        n_vocab = len(pitchnames)

        # prepare_sequences returns inputs already divided by n_vocab, while
        # generate_notes expects raw integer indices (it scales them itself).
        # Undo the scaling so the seed is neither normalized twice nor a
        # float array.
        normalized_input, _ = prepare_sequences(notes, n_vocab)
        seed_sequences = numpy.rint(normalized_input * n_vocab).astype(int)

        music = self.generate_notes(seed_sequences, pitchnames, n_vocab)

        # create_midi hands back what Stream.write returns, which is a file
        # path rather than an open file object, so read the bytes from disk.
        midi_path = self.create_midi(music)
        with open(midi_path, "rb") as midi_file:
            encoded = base64.b64encode(midi_file.read()).decode("utf8")
        data = "data:audio/midi;base64,%s" % encoded

        # commit=False: attach to the current step without advancing it
        wandb.log({
            "midi": wandb.Html("""
                <script type="text/javascript" src="//www.midijs.net/lib/midi.js"></script>
                <button onClick="MIDIjs.play('%s')">Play midi</button>
                <button onClick="MIDIjs.stop()">Stop Playback</button>
            """ % data)
        }, commit=False)
| 203 | + |
| 204 | + |
def train(model, network_input, network_output):
    """Fit ``model`` on the prepared sequences for 200 epochs.

    Three callbacks run, in this order: the Midi sample logger, the wandb
    Keras callback, and a checkpoint that writes ``mozart.hdf5`` whenever
    the training loss reaches a new minimum.
    """
    best_loss_checkpoint = ModelCheckpoint(
        "mozart.hdf5",
        monitor='loss',
        verbose=0,
        save_best_only=True,
        mode='min'
    )

    model.fit(
        network_input,
        network_output,
        epochs=200,
        batch_size=128,
        callbacks=[Midi(), wandb.keras.WandbCallback(), best_loss_checkpoint],
    )
| 219 | + |
| 220 | + |
# Run the full training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    train_network()
0 commit comments