3
3
import datetime
4
4
import io
5
5
import json
6
+ import os
6
7
import pathlib
8
+ import shutil
7
9
import tempfile
8
10
import warnings
9
11
import zipfile
@@ -513,27 +515,38 @@ def save_weights_only(model, filepath, objects_to_skip=None):
513
515
514
516
Note: only supports h5 for now.
515
517
"""
516
- # TODO: if h5 filepath is remote, create the file in a temporary directory
517
- # then upload it
518
518
filepath = str (filepath )
519
+ tmp_dir = None
520
+ remote_filepath = None
519
521
if not filepath .endswith (".weights.h5" ):
520
522
raise ValueError (
521
523
"Invalid `filepath` argument: expected a `.weights.h5` extension. "
522
524
f"Received: filepath={ filepath } "
523
525
)
524
- weights_store = H5IOStore (filepath , mode = "w" )
525
- if objects_to_skip is not None :
526
- visited_saveables = set (id (o ) for o in objects_to_skip )
527
- else :
528
- visited_saveables = set ()
529
- _save_state (
530
- model ,
531
- weights_store = weights_store ,
532
- assets_store = None ,
533
- inner_path = "" ,
534
- visited_saveables = visited_saveables ,
535
- )
536
- weights_store .close ()
526
+ try :
527
+ if file_utils .is_remote_path (filepath ):
528
+ tmp_dir = get_temp_dir ()
529
+ local_filepath = os .path .join (tmp_dir , os .path .basename (filepath ))
530
+ remote_filepath = filepath
531
+ filepath = local_filepath
532
+
533
+ weights_store = H5IOStore (filepath , mode = "w" )
534
+ if objects_to_skip is not None :
535
+ visited_saveables = set (id (o ) for o in objects_to_skip )
536
+ else :
537
+ visited_saveables = set ()
538
+ _save_state (
539
+ model ,
540
+ weights_store = weights_store ,
541
+ assets_store = None ,
542
+ inner_path = "" ,
543
+ visited_saveables = visited_saveables ,
544
+ )
545
+ weights_store .close ()
546
+ finally :
547
+ if tmp_dir is not None :
548
+ file_utils .copy (filepath , remote_filepath )
549
+ shutil .rmtree (tmp_dir )
537
550
538
551
539
552
def load_weights_only (
@@ -544,36 +557,47 @@ def load_weights_only(
544
557
Note: only supports h5 for now.
545
558
"""
546
559
archive = None
560
+ tmp_dir = None
547
561
filepath = str (filepath )
548
- if filepath .endswith (".weights.h5" ):
549
- # TODO: download file if h5 filepath is remote
550
- weights_store = H5IOStore (filepath , mode = "r" )
551
- elif filepath .endswith (".keras" ):
552
- archive = zipfile .ZipFile (filepath , "r" )
553
- weights_store = H5IOStore (_VARS_FNAME_H5 , archive = archive , mode = "r" )
554
-
555
- failed_saveables = set ()
556
- if objects_to_skip is not None :
557
- visited_saveables = set (id (o ) for o in objects_to_skip )
558
- else :
559
- visited_saveables = set ()
560
- error_msgs = {}
561
- _load_state (
562
- model ,
563
- weights_store = weights_store ,
564
- assets_store = None ,
565
- inner_path = "" ,
566
- skip_mismatch = skip_mismatch ,
567
- visited_saveables = visited_saveables ,
568
- failed_saveables = failed_saveables ,
569
- error_msgs = error_msgs ,
570
- )
571
- weights_store .close ()
572
- if archive :
573
- archive .close ()
574
562
575
- if failed_saveables :
576
- _raise_loading_failure (error_msgs , warn_only = skip_mismatch )
563
+ try :
564
+ if file_utils .is_remote_path (filepath ):
565
+ tmp_dir = get_temp_dir ()
566
+ local_filepath = os .path .join (tmp_dir , os .path .basename (filepath ))
567
+ file_utils .copy (filepath , local_filepath )
568
+ filepath = local_filepath
569
+
570
+ if filepath .endswith (".weights.h5" ):
571
+ weights_store = H5IOStore (filepath , mode = "r" )
572
+ elif filepath .endswith (".keras" ):
573
+ archive = zipfile .ZipFile (filepath , "r" )
574
+ weights_store = H5IOStore (_VARS_FNAME_H5 , archive = archive , mode = "r" )
575
+
576
+ failed_saveables = set ()
577
+ if objects_to_skip is not None :
578
+ visited_saveables = set (id (o ) for o in objects_to_skip )
579
+ else :
580
+ visited_saveables = set ()
581
+ error_msgs = {}
582
+ _load_state (
583
+ model ,
584
+ weights_store = weights_store ,
585
+ assets_store = None ,
586
+ inner_path = "" ,
587
+ skip_mismatch = skip_mismatch ,
588
+ visited_saveables = visited_saveables ,
589
+ failed_saveables = failed_saveables ,
590
+ error_msgs = error_msgs ,
591
+ )
592
+ weights_store .close ()
593
+ if archive :
594
+ archive .close ()
595
+
596
+ if failed_saveables :
597
+ _raise_loading_failure (error_msgs , warn_only = skip_mismatch )
598
+ finally :
599
+ if tmp_dir is not None :
600
+ shutil .rmtree (tmp_dir )
577
601
578
602
579
603
def _raise_loading_failure (error_msgs , warn_only = False ):
0 commit comments