Skip to content

Commit f573a86

Browse files
kudkudakfchollet
authored andcommitted
1 parent 0e18cb3 commit f573a86

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

keras/engine/topology.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -927,15 +927,21 @@ def add_update(self, updates, inputs=None):
927927
def get_updates_for(self, inputs):
928928
if not hasattr(self, '_per_input_updates'):
929929
return []
930-
inputs_hash = object_list_uid(inputs)
930+
if inputs is not None:
931+
inputs_hash = object_list_uid(inputs)
932+
else:
933+
inputs_hash = None
931934
if inputs_hash in self._per_input_updates:
932935
return self._per_input_updates[inputs_hash]
933936
return []
934937

935938
def get_losses_for(self, inputs):
936939
if not hasattr(self, '_per_input_losses'):
937940
return []
938-
inputs_hash = object_list_uid(inputs)
941+
if inputs is not None:
942+
inputs_hash = object_list_uid(inputs)
943+
else:
944+
inputs_hash = None
939945
if inputs_hash in self._per_input_losses:
940946
return self._per_input_losses[inputs_hash]
941947
return []

tests/keras/engine/test_topology.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,27 @@
99
from keras.models import model_from_json, model_from_yaml
1010
from keras.utils.test_utils import keras_test
1111

12+
@keras_test
13+
def test_get_updates_for():
14+
a = Input(shape=(2,))
15+
dense_layer = Dense(1)
16+
dense_layer.add_update(0, inputs=a)
17+
dense_layer.add_update(1, inputs=None)
18+
19+
assert dense_layer.get_updates_for(a) == [0]
20+
assert dense_layer.get_updates_for(None) == [1]
21+
22+
23+
@keras_test
24+
def test_get_losses_for():
25+
a = Input(shape=(2,))
26+
dense_layer = Dense(1)
27+
dense_layer.add_loss(0, inputs=a)
28+
dense_layer.add_loss(1, inputs=None)
29+
30+
assert dense_layer.get_losses_for(a) == [0]
31+
assert dense_layer.get_losses_for(None) == [1]
32+
1233

1334
@keras_test
1435
def test_trainable_weights():

0 commit comments

Comments
 (0)