Skip to content

Commit 6800dc8

Browse files
committed
Updates for class
1 parent 98106be commit 6800dc8

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

examples/keras-audio/gru-composer.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def get_notes():
6969
elif isinstance(element, chord.Chord):
7070
notes.append('.'.join(str(n) for n in element.normalOrder))
7171

72+
os.makedirs("data", exist_ok=True)
7273
with open('data/notes', 'wb') as filepath:
7374
pickle.dump(notes, filepath)
7475

@@ -112,18 +113,17 @@ def prepare_sequences(notes, n_vocab):
112113
def create_network(network_input, n_vocab):
113114
""" create the structure of the neural network """
114115
model = tf.keras.Sequential()
115-
model.add(tf.keras.layers.GRU(
116-
256,
116+
model.add(tf.keras.layers.CuDNNGRU(
117+
128,
117118
input_shape=(network_input.shape[1], network_input.shape[2]),
118119
return_sequences=True
119120
))
120-
model.add(tf.keras.layers.Dropout(0.3))
121-
model.add(tf.keras.layers.GRU(128, return_sequences=True))
122-
model.add(tf.keras.layers.Dropout(0.3))
123-
model.add(tf.keras.layers.GRU(64))
124-
model.add(tf.keras.layers.Dense(256))
121+
model.add(tf.keras.layers.CuDNNGRU(64, return_sequences=True))
122+
model.add(tf.keras.layers.CuDNNGRU(32))
123+
model.add(tf.keras.layers.Dense(128, activation="relu"))
125124
model.add(tf.keras.layers.Dropout(0.3))
126125
model.add(tf.keras.layers.Dense(n_vocab))
126+
model.add(tf.keras.layers.Dropout(0.3))
127127
model.add(tf.keras.layers.Activation('softmax'))
128128
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
129129

@@ -156,8 +156,8 @@ def generate_notes(self, network_input, pitchnames, n_vocab):
156156
pattern = list(network_input[start])
157157
prediction_output = []
158158

159-
# generate 500 notes
160-
for note_index in range(500):
159+
# generate 200 notes
160+
for note_index in range(200):
161161
prediction_input = np.reshape(pattern, (1, len(pattern), 1))
162162
prediction_input = prediction_input / float(n_vocab)
163163

examples/keras-audio/wandb/settings

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[default]
2-
entity: qualcomm
3-
project: audiogru-mar28
2+
entity: bloomberg-class
3+
project: audio-nov6
44
base_url: https://api.wandb.ai

0 commit comments

Comments
 (0)