Skip to content

Commit f60a3d0

Browse files
authored
Merge pull request tensorflow#4606 from jhseu/branch_134473452
Branch 134473452
2 parents ccae1fd + 300b7bc commit f60a3d0

File tree

154 files changed

+4113
-2826
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

154 files changed

+4113
-2826
lines changed

tensorflow/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ filegroup(
113113
"//tensorflow/contrib/learn:all_files",
114114
"//tensorflow/contrib/learn/python/learn/datasets:all_files",
115115
"//tensorflow/contrib/linear_optimizer:all_files",
116-
"//tensorflow/contrib/linear_optimizer/kernels:all_files",
117116
"//tensorflow/contrib/lookup:all_files",
118117
"//tensorflow/contrib/losses:all_files",
119118
"//tensorflow/contrib/metrics:all_files",

tensorflow/c/c_api.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,10 +1228,10 @@ int TF_OperationGetControlOutputs(TF_Operation* oper,
12281228
return count;
12291229
}
12301230

1231-
TF_Attr_Metadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1232-
const char* attr_name,
1233-
TF_Status* status) {
1234-
TF_Attr_Metadata metadata;
1231+
TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1232+
const char* attr_name,
1233+
TF_Status* status) {
1234+
TF_AttrMetadata metadata;
12351235
const auto* attr = GetAttrValue(oper, attr_name, status);
12361236
if (!status->status.ok()) return metadata;
12371237
switch (attr->value_case()) {

tensorflow/c/c_api.h

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ extern int TF_OperationGetControlOutputs(TF_Operation* oper,
537537
TF_Operation** control_outputs,
538538
int max_control_outputs);
539539

540-
// TF_Attr_Type describes the type of the value of an attribute on an operation.
540+
// TF_AttrType describes the type of the value of an attribute on an operation.
541541
typedef enum {
542542
TF_ATTR_STRING = 0,
543543
TF_ATTR_INT = 1,
@@ -548,9 +548,9 @@ typedef enum {
548548
TF_ATTR_TENSOR = 6,
549549
TF_ATTR_PLACEHOLDER = 7,
550550
TF_ATTR_FUNC = 8,
551-
} TF_Attr_Type;
551+
} TF_AttrType;
552552

553-
// TF_Attr_Metadata describes the value of an attribute on an operation.
553+
// TF_AttrMetadata describes the value of an attribute on an operation.
554554
typedef struct {
555555
// A boolean: 1 if the attribute value is a list, 0 otherwise.
556556
unsigned char is_list;
@@ -560,7 +560,7 @@ typedef struct {
560560

561561
// Type of elements of the list if is_list != 0.
562562
// Type of the single value stored in the attribute if is_list == 0.
563-
TF_Attr_Type type;
563+
TF_AttrType type;
564564

565565
// Total size the attribute value.
566566
// The units of total_size depend on is_list and type.
@@ -579,16 +579,16 @@ typedef struct {
579579
// of dimensions of all shapes in the list.
580580
// (5) Otherwise, total_size is undefined.
581581
int64_t total_size;
582-
} TF_Attr_Metadata;
582+
} TF_AttrMetadata;
583583

584584
// Returns metadata about the value of the attribute `attr_name` of `oper`.
585-
TF_Attr_Metadata TF_OperationGetAttrMetadata(TF_Operation* oper,
586-
const char* attr_name,
587-
TF_Status* status);
585+
TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
586+
const char* attr_name,
587+
TF_Status* status);
588588

589589
// Fills in `value` with the value of the attribute `attr_name`. `value` must
590590
// point to an array of length at least `max_length` (ideally set to
591-
// TF_Attr_Metadata.total_size from TF_OperationGetAttrMetadata(oper,
591+
// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper,
592592
// attr_name)).
593593
extern void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
594594
void* value, int max_length,
@@ -600,8 +600,8 @@ extern void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
600600
//
601601
// The elements of values will point to addresses in `storage` which must be at
602602
// least `storage_size` bytes large. Ideally, max_values would be set to
603-
// TF_Attr_Metadata.list_size and `storage` would be at least
604-
// TF_Attr_Metadata.total_size, obtained from TF_OperationGetAttrMetadata(oper,
603+
// TF_AttrMetadata.list_size and `storage` would be at least
604+
// TF_AttrMetadata.total_size, obtained from TF_OperationGetAttrMetadata(oper,
605605
// attr_name).
606606
//
607607
// Fails if storage_size is too small to hold the requested number of strings.
@@ -616,7 +616,7 @@ extern void TF_OperationGetAttrInt(TF_Operation* oper, const char* attr_name,
616616

617617
// Fills in `values` with the value of the attribute `attr_name` of `oper`.
618618
// `values` must point to an array of length at least `max_values` (ideally set
619-
// TF_Attr_Metadata.list_size from TF_OperationGetAttrMetadata(oper,
619+
// TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper,
620620
// attr_name)).
621621
extern void TF_OperationGetAttrIntList(TF_Operation* oper,
622622
const char* attr_name, int64_t* values,
@@ -627,7 +627,7 @@ extern void TF_OperationGetAttrFloat(TF_Operation* oper, const char* attr_name,
627627

628628
// Fills in `values` with the value of the attribute `attr_name` of `oper`.
629629
// `values` must point to an array of length at least `max_values` (ideally set
630-
// to TF_Attr_Metadata.list_size from TF_OperationGetAttrMetadata(oper,
630+
// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper,
631631
// attr_name)).
632632
extern void TF_OperationGetAttrFloatList(TF_Operation* oper,
633633
const char* attr_name, float* values,
@@ -638,7 +638,7 @@ extern void TF_OperationGetAttrBool(TF_Operation* oper, const char* attr_name,
638638

639639
// Fills in `values` with the value of the attribute `attr_name` of `oper`.
640640
// `values` must point to an array of length at least `max_values` (ideally set
641-
// to TF_Attr_Metadata.list_size from TF_OperationGetAttrMetadata(oper,
641+
// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper,
642642
// attr_name)).
643643
extern void TF_OperationGetAttrBoolList(TF_Operation* oper,
644644
const char* attr_name,
@@ -650,7 +650,7 @@ extern void TF_OperationGetAttrType(TF_Operation* oper, const char* attr_name,
650650

651651
// Fills in `values` with the value of the attribute `attr_name` of `oper`.
652652
// `values` must point to an array of length at least `max_values` (ideally set
653-
// to TF_Attr_Metadata.list_size from TF_OperationGetAttrMetadata(oper,
653+
// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper,
654654
// attr_name)).
655655
extern void TF_OperationGetAttrTypeList(TF_Operation* oper,
656656
const char* attr_name,
@@ -672,8 +672,8 @@ extern void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
672672
//
673673
// The elements of `dims` will point to addresses in `storage` which must be
674674
// large enough to hold at least `storage_size` int64_ts. Ideally, `num_shapes`
675-
// would be set to TF_Attr_Metadata.list_size and `storage_size` would be set to
676-
// TF_Attr_Metadata.total_size from TF_OperationGetAttrMetadata(oper,
675+
// would be set to TF_AttrMetadata.list_size and `storage_size` would be set to
676+
// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper,
677677
// attr_name).
678678
//
679679
// Fails if storage_size is insufficient to hold the requested shapes.
@@ -692,7 +692,7 @@ extern void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
692692

693693
// Fills in `values` with binary-serialized TensorShapeProto values of the
694694
// attribute `attr_name` of `oper`. `values` must point to an array of length at
695-
// least `num_values` (ideally set to TF_Attr_Metadata.list_size from
695+
// least `num_values` (ideally set to TF_AttrMetadata.list_size from
696696
// TF_OperationGetAttrMetadata(oper, attr_name)).
697697
extern void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
698698
const char* attr_name,
@@ -709,7 +709,7 @@ extern void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
709709

710710
// Fills in `values` with the TF_Tensor values of the attribute `attr_name` of
711711
// `oper`. `values` must point to an array of TF_Tensor* of length at least
712-
// `max_values` (ideally set to TF_Attr_Metadata.list_size from
712+
// `max_values` (ideally set to TF_AttrMetadata.list_size from
713713
// TF_OperationGetAttrMetadata(oper, attr_name)).
714714
//
715715
// The caller takes ownership of all the non-null TF_Tensor* entries in `values`

tensorflow/c/c_api_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@ TEST(CAPI, ColocateWith) {
864864
TF_Operation* add = TF_FinishOperation(desc, s);
865865
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
866866

867-
TF_Attr_Metadata m =
867+
TF_AttrMetadata m =
868868
TF_OperationGetAttrMetadata(add, tensorflow::kColocationAttrName, s);
869869
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
870870
EXPECT_EQ(1, m.is_list);

tensorflow/cc/saved_model/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ filegroup(
5252
srcs = glob([
5353
"testdata/half_plus_two/**",
5454
"testdata/half_plus_two_pbtxt/**",
55+
"testdata/half_plus_two_sharded/**",
5556
]),
5657
)
5758

tensorflow/cc/saved_model/constants.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ constexpr char kSavedModelVariablesDirectory[] = "variables";
3030
// SavedModel variables filename.
3131
constexpr char kSavedModelVariablesFilename[] = "saved_model_variables";
3232

33+
// SavedModel sharded variables filename.
34+
constexpr char kSavedModelVariablesShardedFilename[] =
35+
"saved_model_variables-\?\?\?\?\?-of-\?\?\?\?\?";
36+
3337
// Commonly used tags.
3438
constexpr char kSavedModelTagServe[] = "serve";
3539
constexpr char kSavedModelTagTrain[] = "train";

tensorflow/cc/saved_model/loader.cc

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,18 @@ Status Restore(const RunOptions& run_options, const string& export_dir,
8686
const StringPiece restore_op_name,
8787
const StringPiece variable_filename_const_op_name,
8888
Session* session) {
89-
const string variables_path = io::JoinPath(
90-
export_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename);
91-
if (!Env::Default()->FileExists(variables_path)) {
92-
return Status(
93-
error::Code::NOT_FOUND,
94-
"Could not find checkpointed variables at: " + variables_path);
89+
// Find path to variables to be restored in export directory.
90+
string variables_path =
91+
io::JoinPath(export_dir, kSavedModelVariablesDirectory);
92+
const string unsharded_variables_path =
93+
io::JoinPath(variables_path, kSavedModelVariablesFilename);
94+
if (Env::Default()->FileExists(unsharded_variables_path)) {
95+
variables_path = unsharded_variables_path;
96+
} else {
97+
const string sharded_variables_path =
98+
io::JoinPath(variables_path, kSavedModelVariablesShardedFilename);
99+
variables_path = sharded_variables_path;
95100
}
96-
97101
// Add variables to the graph.
98102
Tensor variables_path_tensor(DT_STRING, TensorShape({}));
99103
variables_path_tensor.scalar<string>()() = variables_path;

tensorflow/cc/saved_model/loader_test.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ namespace {
2828

2929
constexpr char kTestDataPb[] = "cc/saved_model/testdata/half_plus_two";
3030
constexpr char kTestDataPbTxt[] = "cc/saved_model/testdata/half_plus_two_pbtxt";
31+
constexpr char kTestDataSharded[] =
32+
"cc/saved_model/testdata/half_plus_two_sharded";
3133

3234
class LoaderTest : public ::testing::Test {
3335
protected:
@@ -110,6 +112,18 @@ TEST_F(LoaderTest, PbtxtFormat) {
110112
CheckSavedModelBundle(bundle);
111113
}
112114

115+
TEST_F(LoaderTest, ShardedVariables) {
116+
SavedModelBundle bundle;
117+
SessionOptions session_options;
118+
RunOptions run_options;
119+
120+
const string export_dir =
121+
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
122+
TF_ASSERT_OK(LoadSavedModel(export_dir, {kSavedModelTagServe},
123+
session_options, run_options, &bundle));
124+
CheckSavedModelBundle(bundle);
125+
}
126+
113127
TEST_F(LoaderTest, InvalidExportPath) {
114128
SavedModelBundle bundle;
115129
RunOptions run_options;
Binary file not shown.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
model_checkpoint_path: "/tmp/saved_model/half_plus_two/variables/saved_model_variables-?????-of-00001"
2+
all_model_checkpoint_paths: "/tmp/saved_model/half_plus_two/variables/saved_model_variables-?????-of-00001"

0 commit comments

Comments
 (0)