Skip to content

Commit a9118d4

Browse files
committed
changes data loading and loss to match video
1 parent ab6131d commit a9118d4

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

videos/time-series/rnn.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,17 @@
2626
config.repeated_predictions = False
2727
config.look_back = 20
2828

29-
df = pd.read_csv('daily-min-temperatures.csv')
30-
data = df.temp.astype('float32').values
29+
def load_data(data_type="airline"):
30+
if data_type == "flu":
31+
df = pd.read_csv('flusearches.csv')
32+
data = df.flu.astype('float32').values
33+
elif data_type == "airline":
34+
df = pd.read_csv('international-airline-passengers.csv')
35+
data = df.passengers.astype('float32').values
36+
elif data_type == "sin":
37+
df = pd.read_csv('sin.csv')
38+
data = df.sin.astype('float32').values
39+
return data
3140

3241
# convert an array of values into a dataset matrix
3342
def create_dataset(dataset):
@@ -39,7 +48,7 @@ def create_dataset(dataset):
3948
return np.array(dataX), np.array(dataY)
4049

4150
data = load_data()
42-
51+
4352
# normalize data to between 0 and 1
4453
max_val = max(data)
4554
min_val = min(data)
@@ -59,7 +68,7 @@ def create_dataset(dataset):
5968
# create and fit the RNN
6069
model = Sequential()
6170
model.add(SimpleRNN(1, input_shape=(config.look_back,1 )))
62-
model.compile(loss='mae', optimizer='adam')
71+
model.compile(loss='mse', optimizer='adam')
6372
model.fit(trainX, trainY, epochs=1000, batch_size=1, validation_data=(testX, testY), callbacks=[WandbCallback(), PlotCallback(trainX, trainY, testX, testY, config.look_back)])
6473

6574

0 commit comments

Comments
 (0)