Skip to content

Commit 74fa169

Browse files
committed
Gru updates
1 parent 5ff77b2 commit 74fa169

File tree

1 file changed

+33
-8
lines changed

1 file changed

+33
-8
lines changed

keras-audio/gru-composer.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
network for training """
33
import glob
44
import pickle
5-
import numpy
5+
import numpy as np
66
from music21 import converter, instrument, note, chord, stream
77
import os
88
from keras.models import Sequential
@@ -12,10 +12,18 @@
1212
from keras.layers import Activation
1313
from keras.utils import np_utils
1414
from keras.callbacks import ModelCheckpoint, Callback
15+
import subprocess
1516
import wandb
1617
import base64
1718
wandb.init()
1819

20+
def ensure_midi(dataset="mario"):
21+
if not os.path.exists("data/%s" % dataset):
22+
print("Downloading %s dataset..." % dataset)
23+
subprocess.check_output(
24+
"curl -SL https://storage.googleapis.com/wandb/%s.tar.gz | tar xz" % dataset, shell=True) #finalfantasy
25+
open("data/%s" % dataset, "w").close()
26+
1927

2028
def train_network():
2129
""" Train a Neural Network to generate music """
@@ -38,7 +46,11 @@ def get_notes():
3846
return pickle.load(open("data/notes", "rb"))
3947

4048
for file in glob.glob("midi_songs/*.mid"):
41-
midi = converter.parse(file)
49+
try:
50+
midi = converter.parse(file)
51+
except TypeError:
52+
print("Invalid file %s" % file)
53+
continue
4254

4355
print("Parsing %s" % file)
4456

@@ -86,7 +98,7 @@ def prepare_sequences(notes, n_vocab):
8698
n_patterns = len(network_input)
8799

88100
# reshape the input into a format compatible with LSTM layers
89-
network_input = numpy.reshape(
101+
network_input = np.reshape(
90102
network_input, (n_patterns, sequence_length, 1))
91103
# normalize input
92104
network_input = network_input / float(n_vocab)
@@ -118,27 +130,39 @@ def create_network(network_input, n_vocab):
118130

119131

120132
class Midi(Callback):
133+
"""
134+
Callback for sampling a midi file
135+
"""
136+
def sample(self, preds, temperature=1.0):
137+
# helper function to sample an index from a probability array
138+
preds = np.asarray(preds).astype('float64')
139+
preds = np.log(preds) / temperature
140+
exp_preds = np.exp(preds)
141+
preds = exp_preds / np.sum(exp_preds)
142+
probas = np.random.multinomial(1, preds, 1)
143+
return np.argmax(probas)
144+
121145
def generate_notes(self, network_input, pitchnames, n_vocab):
122146
""" Generate notes from the neural network based on a sequence of notes """
123147
# pick a random sequence from the input as a starting point for the prediction
124148
model = self.model
125-
start = numpy.random.randint(0, len(network_input)-1)
149+
start = np.random.randint(0, len(network_input)-1)
126150

127151
int_to_note = dict((number, note)
128152
for number, note in enumerate(pitchnames))
129153

130-
pattern = network_input[start]
154+
pattern = list(network_input[start])
131155
prediction_output = []
132156

133157
# generate 500 notes
134158
for note_index in range(500):
135-
prediction_input = numpy.reshape(pattern, (1, len(pattern), 1))
159+
prediction_input = np.reshape(pattern, (1, len(pattern), 1))
136160
prediction_input = prediction_input / float(n_vocab)
137161

138162
prediction = model.predict(prediction_input, verbose=0)
139163

140164
# TODO: add random picking
141-
index = numpy.argmax(prediction)
165+
index = np.argmax(prediction)#self.sample(prediction)#np.argmax
142166
result = int_to_note[index]
143167
prediction_output.append(result)
144168

@@ -190,7 +214,7 @@ def on_epoch_end(self, *args):
190214
notes, n_vocab)
191215
music = self.generate_notes(network_input, pitchnames, n_vocab)
192216
midi = self.create_midi(music)
193-
midi.seek(0)
217+
midi = open(midi, "rb")
194218
data = "data:audio/midi;base64,%s" % base64.b64encode(
195219
midi.read()).decode("utf8")
196220
wandb.log({
@@ -219,4 +243,5 @@ def train(model, network_input, network_output):
219243

220244

221245
if __name__ == '__main__':
246+
ensure_midi("finalfantasy")
222247
train_network()

0 commit comments

Comments
 (0)