Skip to content

Commit b3e0b46

Browse files
committed
Add support for loading/saving remote weights files.
1 parent e354487 commit b3e0b46

File tree

2 files changed

+80
-43
lines changed

2 files changed

+80
-43
lines changed

keras/src/saving/saving_lib.py

Lines changed: 67 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import datetime
44
import io
55
import json
6+
import os
67
import pathlib
8+
import shutil
79
import tempfile
810
import warnings
911
import zipfile
@@ -513,27 +515,38 @@ def save_weights_only(model, filepath, objects_to_skip=None):
513515
514516
Note: only supports h5 for now.
515517
"""
516-
# TODO: if h5 filepath is remote, create the file in a temporary directory
517-
# then upload it
518518
filepath = str(filepath)
519+
tmp_dir = None
520+
remote_filepath = None
519521
if not filepath.endswith(".weights.h5"):
520522
raise ValueError(
521523
"Invalid `filepath` argument: expected a `.weights.h5` extension. "
522524
f"Received: filepath={filepath}"
523525
)
524-
weights_store = H5IOStore(filepath, mode="w")
525-
if objects_to_skip is not None:
526-
visited_saveables = set(id(o) for o in objects_to_skip)
527-
else:
528-
visited_saveables = set()
529-
_save_state(
530-
model,
531-
weights_store=weights_store,
532-
assets_store=None,
533-
inner_path="",
534-
visited_saveables=visited_saveables,
535-
)
536-
weights_store.close()
526+
try:
527+
if file_utils.is_remote_path(filepath):
528+
tmp_dir = get_temp_dir()
529+
local_filepath = os.path.join(tmp_dir, os.path.basename(filepath))
530+
remote_filepath = filepath
531+
filepath = local_filepath
532+
533+
weights_store = H5IOStore(filepath, mode="w")
534+
if objects_to_skip is not None:
535+
visited_saveables = set(id(o) for o in objects_to_skip)
536+
else:
537+
visited_saveables = set()
538+
_save_state(
539+
model,
540+
weights_store=weights_store,
541+
assets_store=None,
542+
inner_path="",
543+
visited_saveables=visited_saveables,
544+
)
545+
weights_store.close()
546+
finally:
547+
if tmp_dir is not None:
548+
file_utils.copy(filepath, remote_filepath)
549+
shutil.rmtree(tmp_dir)
537550

538551

539552
def load_weights_only(
@@ -544,36 +557,47 @@ def load_weights_only(
544557
Note: only supports h5 for now.
545558
"""
546559
archive = None
560+
tmp_dir = None
547561
filepath = str(filepath)
548-
if filepath.endswith(".weights.h5"):
549-
# TODO: download file if h5 filepath is remote
550-
weights_store = H5IOStore(filepath, mode="r")
551-
elif filepath.endswith(".keras"):
552-
archive = zipfile.ZipFile(filepath, "r")
553-
weights_store = H5IOStore(_VARS_FNAME_H5, archive=archive, mode="r")
554-
555-
failed_saveables = set()
556-
if objects_to_skip is not None:
557-
visited_saveables = set(id(o) for o in objects_to_skip)
558-
else:
559-
visited_saveables = set()
560-
error_msgs = {}
561-
_load_state(
562-
model,
563-
weights_store=weights_store,
564-
assets_store=None,
565-
inner_path="",
566-
skip_mismatch=skip_mismatch,
567-
visited_saveables=visited_saveables,
568-
failed_saveables=failed_saveables,
569-
error_msgs=error_msgs,
570-
)
571-
weights_store.close()
572-
if archive:
573-
archive.close()
574562

575-
if failed_saveables:
576-
_raise_loading_failure(error_msgs, warn_only=skip_mismatch)
563+
try:
564+
if file_utils.is_remote_path(filepath):
565+
tmp_dir = get_temp_dir()
566+
local_filepath = os.path.join(tmp_dir, os.path.basename(filepath))
567+
file_utils.copy(filepath, local_filepath)
568+
filepath = local_filepath
569+
570+
if filepath.endswith(".weights.h5"):
571+
weights_store = H5IOStore(filepath, mode="r")
572+
elif filepath.endswith(".keras"):
573+
archive = zipfile.ZipFile(filepath, "r")
574+
weights_store = H5IOStore(_VARS_FNAME_H5, archive=archive, mode="r")
575+
576+
failed_saveables = set()
577+
if objects_to_skip is not None:
578+
visited_saveables = set(id(o) for o in objects_to_skip)
579+
else:
580+
visited_saveables = set()
581+
error_msgs = {}
582+
_load_state(
583+
model,
584+
weights_store=weights_store,
585+
assets_store=None,
586+
inner_path="",
587+
skip_mismatch=skip_mismatch,
588+
visited_saveables=visited_saveables,
589+
failed_saveables=failed_saveables,
590+
error_msgs=error_msgs,
591+
)
592+
weights_store.close()
593+
if archive:
594+
archive.close()
595+
596+
if failed_saveables:
597+
_raise_loading_failure(error_msgs, warn_only=skip_mismatch)
598+
finally:
599+
if tmp_dir is not None:
600+
shutil.rmtree(tmp_dir)
577601

578602

579603
def _raise_loading_failure(error_msgs, warn_only=False):

keras/src/saving/saving_lib_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,3 +1049,16 @@ def test_bidirectional_lstm_saving(self):
10491049
ref_out = model(x)
10501050
out = new_model(x)
10511051
self.assertAllClose(ref_out, out)
1052+
1053+
def test_remove_weights_only_saving_and_loading(self):
1054+
def is_remote_path(path):
1055+
return True
1056+
1057+
temp_filepath = os.path.join(self.get_temp_dir(), "model.weights.h5")
1058+
1059+
with mock.patch(
1060+
"keras.src.utils.file_utils.is_remote_path", is_remote_path
1061+
):
1062+
model = _get_subclassed_model()
1063+
model.save_weights(temp_filepath)
1064+
model.load_weights(temp_filepath)

0 commit comments

Comments
 (0)