Skip to content

Commit 8d7f4a1

Browse files
authored
Fix Mnist example running in two browsers at the same time (streamlit#216)
* update mnist demo to use TF 2.0 api * upgrade tensorflow to 2.0.0
1 parent bf8e883 commit 8d7f4a1

File tree

4 files changed

+311
-185
lines changed

4 files changed

+311
-185
lines changed

examples/mnist-cnn.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,27 @@
2424
import streamlit as st
2525
from streamlit import config
2626

27-
from keras.datasets import mnist
28-
from keras.layers import Conv2D, MaxPooling2D, Dropout, Dense, Flatten
29-
from keras.models import Sequential
30-
from keras.optimizers import SGD
27+
from tensorflow.keras.datasets import mnist
28+
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, Dense, Flatten
29+
from tensorflow.keras.models import Sequential
30+
from tensorflow.keras.optimizers import SGD
3131
from keras.utils import np_utils
32-
import keras
32+
from tensorflow import keras
3333
import math
3434
import numpy as np
3535
import pandas as pd
3636
import time
3737

38-
# https://kobkrit.com/using-allow-growth-memory-option-in-tensorflow-and-keras-dc8c8081bc96
39-
from keras.backend.tensorflow_backend import set_session
4038
import tensorflow as tf
4139

42-
tf_config = tf.ConfigProto()
4340
# dynamically grow the memory used on the GPU
4441
# this option is fine on non gpus as well.
42+
tf_config = tf.compat.v1.ConfigProto()
4543
tf_config.gpu_options.allow_growth = True
4644
tf_config.log_device_placement = True
47-
set_session(tf.Session(config=tf_config))
45+
46+
# https://kobkrit.com/using-allow-growth-memory-option-in-tensorflow-and-keras-dc8c8081bc96
47+
tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=tf_config))
4848

4949

5050
class MyCallback(keras.callbacks.Callback):
@@ -67,10 +67,12 @@ def on_epoch_begin(self, epoch, logs=None):
6767

6868
def on_batch_end(self, batch, logs=None):
6969
rows = pd.DataFrame(
70-
[[logs["loss"], logs["accuracy"]]], columns=["loss", "accuracy"])
70+
[[logs["loss"], logs["accuracy"]]], columns=["loss", "accuracy"]
71+
)
7172
if batch % 10 == 0:
72-
self._epoch_chart.add_rows({"loss": [logs["loss"]],
73-
"accuracy": [logs["accuracy"]]})
73+
self._epoch_chart.add_rows(
74+
{"loss": [logs["loss"]], "accuracy": [logs["accuracy"]]}
75+
)
7476
if batch % 100 == 99:
7577
self._summary_chart.add_rows(rows)
7678
percent_complete = logs["batch"] * logs["size"] / self.params["samples"]
@@ -96,6 +98,7 @@ def on_epoch_end(self, epoch, logs=None):
9698
% {"epoch": epoch, "summary": summary}
9799
)
98100

101+
99102
st.title("MNIST CNN")
100103

101104
(x_train, y_train), (x_test, y_test) = mnist.load_data()
@@ -134,8 +137,7 @@ def on_epoch_end(self, epoch, logs=None):
134137
model.add(Dense(8, activation="relu"))
135138
model.add(Dense(num_classes, activation="softmax"))
136139

137-
model.compile(
138-
loss="categorical_crossentropy", optimizer=sgd, metrics=["accuracy"])
140+
model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=["accuracy"])
139141

140142
show_terminal_output = not config.get_option("server.liveSave")
141143
model.fit(

lib/Pipfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ bokeh = "*"
2525
graphviz = "*"
2626
parameterized = "*"
2727
pydot = "*"
28-
tensorflow = "*"
28+
tensorflow = ">=2.0.0"
2929
seaborn = "*"
3030
prometheus-client = "*"
3131
opencv-python = "*"

0 commit comments

Comments
 (0)