Skip to content

Commit d94e91f

Browse files
committed
Fix warnings
1 parent 525c65b commit d94e91f

File tree

3 files changed

+20
-15
lines changed

3 files changed

+20
-15
lines changed

keras/src/callbacks/model_checkpoint.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -248,11 +248,10 @@ def _should_save_model(self, epoch, batch, logs, filepath):
248248
current = logs.get(self.monitor)
249249
if current is None:
250250
warnings.warn(
251-
f"Can save best model only with {self.monitor} "
252-
"available, skipping.",
251+
f"Can save best model only with {self.monitor} available.",
253252
stacklevel=2,
254253
)
255-
return False
254+
return True
256255
elif (
257256
isinstance(current, np.ndarray) or backend.is_tensor(current)
258257
) and len(current.shape) > 0:

keras/src/callbacks/model_checkpoint_test.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -167,17 +167,17 @@ def get_model():
167167
filepath, monitor="unknown", save_best_only=True
168168
)
169169
]
170-
model.fit(
171-
x_train,
172-
y_train,
173-
batch_size=BATCH_SIZE,
174-
validation_data=(x_test, y_test),
175-
callbacks=cbks,
176-
epochs=1,
177-
verbose=0,
178-
)
179-
# File won't be written.
180-
self.assertFalse(os.path.exists(filepath))
170+
with pytest.warns(UserWarning):
171+
model.fit(
172+
x_train,
173+
y_train,
174+
batch_size=BATCH_SIZE,
175+
validation_data=(x_test, y_test),
176+
callbacks=cbks,
177+
epochs=1,
178+
verbose=0,
179+
)
180+
self.assertTrue(os.path.exists(filepath))
181181

182182
# Case 6
183183
with warnings.catch_warnings(record=True) as warning_logs:

keras/src/models/functional.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,13 @@ def _adjust_input_rank(self, flat_inputs):
288288

289289
def _standardize_inputs(self, inputs):
290290
raise_exception = False
291-
if isinstance(inputs, dict) and not isinstance(
291+
if (
292+
isinstance(self._inputs_struct, list)
293+
and len(self._inputs_struct) == 1
294+
and ops.is_tensor(inputs)
295+
):
296+
inputs = [inputs]
297+
elif isinstance(inputs, dict) and not isinstance(
292298
self._inputs_struct, dict
293299
):
294300
# This is to avoid warning

0 commit comments

Comments
 (0)