Skip to content

Commit 75e1ab5

Browse files
committed
More class material
1 parent 12939e8 commit 75e1ab5

File tree

3 files changed

+454
-18
lines changed

3 files changed

+454
-18
lines changed

keras-audio/gru-composer.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
""" This module prepares midi file data and feeds it to the neural
2+
network for training """
3+
import glob
4+
import pickle
5+
import numpy
6+
from music21 import converter, instrument, note, chord, stream
7+
import os
8+
from keras.models import Sequential
9+
from keras.layers import Dense
10+
from keras.layers import Dropout
11+
from keras.layers import LSTM, CuDNNGRU
12+
from keras.layers import Activation
13+
from keras.utils import np_utils
14+
from keras.callbacks import ModelCheckpoint, Callback
15+
import wandb
16+
import base64
17+
wandb.init()
18+
19+
20+
def train_network():
21+
""" Train a Neural Network to generate music """
22+
notes = get_notes()
23+
24+
# get amount of pitch names
25+
n_vocab = len(set(notes))
26+
27+
network_input, network_output = prepare_sequences(notes, n_vocab)
28+
29+
model = create_network(network_input, n_vocab)
30+
31+
train(model, network_input, network_output)
32+
33+
34+
def get_notes():
35+
""" Get all the notes and chords from the midi files in the ./midi_songs directory """
36+
notes = []
37+
if os.path.exists("data/notes"):
38+
return pickle.load(open("data/notes", "rb"))
39+
40+
for file in glob.glob("midi_songs/*.mid"):
41+
midi = converter.parse(file)
42+
43+
print("Parsing %s" % file)
44+
45+
notes_to_parse = None
46+
47+
try: # file has instrument parts
48+
s2 = instrument.partitionByInstrument(midi)
49+
notes_to_parse = s2.parts[0].recurse()
50+
except: # file has notes in a flat structure
51+
notes_to_parse = midi.flat.notes
52+
53+
for element in notes_to_parse:
54+
if isinstance(element, note.Note):
55+
notes.append(str(element.pitch))
56+
elif isinstance(element, chord.Chord):
57+
notes.append('.'.join(str(n) for n in element.normalOrder))
58+
59+
with open('data/notes', 'wb') as filepath:
60+
pickle.dump(notes, filepath)
61+
62+
return notes
63+
64+
65+
def prepare_sequences(notes, n_vocab):
66+
""" Prepare the sequences used by the Neural Network """
67+
sequence_length = 100
68+
69+
# get all pitch names
70+
pitchnames = sorted(set(item for item in notes))
71+
72+
# create a dictionary to map pitches to integers
73+
note_to_int = dict((note, number)
74+
for number, note in enumerate(pitchnames))
75+
76+
network_input = []
77+
network_output = []
78+
79+
# create input sequences and the corresponding outputs
80+
for i in range(0, len(notes) - sequence_length, 1):
81+
sequence_in = notes[i:i + sequence_length]
82+
sequence_out = notes[i + sequence_length]
83+
network_input.append([note_to_int[char] for char in sequence_in])
84+
network_output.append(note_to_int[sequence_out])
85+
86+
n_patterns = len(network_input)
87+
88+
# reshape the input into a format compatible with LSTM layers
89+
network_input = numpy.reshape(
90+
network_input, (n_patterns, sequence_length, 1))
91+
# normalize input
92+
network_input = network_input / float(n_vocab)
93+
94+
network_output = np_utils.to_categorical(network_output)
95+
96+
return (network_input, network_output)
97+
98+
99+
def create_network(network_input, n_vocab):
100+
""" create the structure of the neural network """
101+
model = Sequential()
102+
model.add(CuDNNGRU(
103+
256,
104+
input_shape=(network_input.shape[1], network_input.shape[2]),
105+
return_sequences=True
106+
))
107+
model.add(Dropout(0.3))
108+
model.add(CuDNNGRU(128, return_sequences=True))
109+
model.add(Dropout(0.3))
110+
model.add(CuDNNGRU(64))
111+
model.add(Dense(256))
112+
model.add(Dropout(0.3))
113+
model.add(Dense(n_vocab))
114+
model.add(Activation('softmax'))
115+
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
116+
117+
return model
118+
119+
120+
class Midi(Callback):
121+
def generate_notes(self, network_input, pitchnames, n_vocab):
122+
""" Generate notes from the neural network based on a sequence of notes """
123+
# pick a random sequence from the input as a starting point for the prediction
124+
model = self.model
125+
start = numpy.random.randint(0, len(network_input)-1)
126+
127+
int_to_note = dict((number, note)
128+
for number, note in enumerate(pitchnames))
129+
130+
pattern = network_input[start]
131+
prediction_output = []
132+
133+
# generate 500 notes
134+
for note_index in range(500):
135+
prediction_input = numpy.reshape(pattern, (1, len(pattern), 1))
136+
prediction_input = prediction_input / float(n_vocab)
137+
138+
prediction = model.predict(prediction_input, verbose=0)
139+
140+
# TODO: add random picking
141+
index = numpy.argmax(prediction)
142+
result = int_to_note[index]
143+
prediction_output.append(result)
144+
145+
pattern.append(index)
146+
pattern = pattern[1:len(pattern)]
147+
148+
return prediction_output
149+
150+
def create_midi(self, prediction_output):
151+
""" convert the output from the prediction to notes and create a midi file
152+
from the notes """
153+
offset = 0
154+
output_notes = []
155+
156+
# create note and chord objects based on the values generated by the model
157+
for pattern in prediction_output:
158+
# pattern is a chord
159+
if ('.' in pattern) or pattern.isdigit():
160+
notes_in_chord = pattern.split('.')
161+
notes = []
162+
for current_note in notes_in_chord:
163+
new_note = note.Note(int(current_note))
164+
new_note.storedInstrument = instrument.Piano()
165+
notes.append(new_note)
166+
new_chord = chord.Chord(notes)
167+
new_chord.offset = offset
168+
output_notes.append(new_chord)
169+
# pattern is a note
170+
else:
171+
new_note = note.Note(pattern)
172+
new_note.offset = offset
173+
new_note.storedInstrument = instrument.Piano()
174+
output_notes.append(new_note)
175+
176+
# increase offset each iteration so that notes do not stack
177+
offset += 0.5
178+
179+
midi_stream = stream.Stream(output_notes)
180+
181+
return midi_stream.write('midi')
182+
183+
def on_epoch_end(self, *args):
184+
notes = get_notes()
185+
# Get all pitch names
186+
pitchnames = sorted(set(item for item in notes))
187+
# Get all pitch names
188+
n_vocab = len(set(notes))
189+
network_input, normalized_input = prepare_sequences(
190+
notes, n_vocab)
191+
music = self.generate_notes(network_input, pitchnames, n_vocab)
192+
midi = self.create_midi(music)
193+
midi.seek(0)
194+
data = "data:audio/midi;base64,%s" % base64.b64encode(
195+
midi.read()).decode("utf8")
196+
wandb.log({
197+
"midi": wandb.Html("""
198+
<script type="text/javascript" src="//www.midijs.net/lib/midi.js"></script>
199+
<button onClick="MIDIjs.play('%s')">Play midi</button>
200+
<button onClick="MIDIjs.stop()">Stop Playback</button>
201+
""" % data)
202+
}, commit=False)
203+
204+
205+
def train(model, network_input, network_output):
206+
""" train the neural network """
207+
filepath = "mozart.hdf5"
208+
checkpoint = ModelCheckpoint(
209+
filepath,
210+
monitor='loss',
211+
verbose=0,
212+
save_best_only=True,
213+
mode='min'
214+
)
215+
callbacks_list = [Midi(), wandb.keras.WandbCallback(), checkpoint]
216+
217+
model.fit(network_input, network_output, epochs=200,
218+
batch_size=128, callbacks=callbacks_list)
219+
220+
221+
if __name__ == '__main__':
222+
train_network()

0 commit comments

Comments
 (0)