 import pickle
+from collections import namedtuple

 import numpy as np
 import pytest
 from absl.testing import parameterized

 from keras.src import backend
 from keras.src import layers
+from keras.src import losses
 from keras.src import testing
+from keras.src import tree
 from keras.src.layers.core.input_layer import Input
 from keras.src.models.functional import Functional
 from keras.src.models.model import Model
@@ -68,6 +71,48 @@ def _get_model_multi_outputs_dict():
     return model


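+# The helpers below build functional models whose outputs are nested
+# structures (list/tuple, namedtuple, dict, and a deeply nested mix) so the
+# tests can exercise structured `loss` arguments.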
+def _get_model_multi_outputs_struct_list_like(_type):
+    x = Input(shape=(3,), name="x")
+    y1 = layers.Dense(1, name="y1", activation="sigmoid")(x)
+    y2 = layers.Dense(1, name="y2", activation="sigmoid")(x)
+    model = Model(x, _type([y1, y2]))
+    return model
+
+
+def _get_model_multi_outputs_struct_namedtuple():
+    Y = namedtuple("Y", ["y1", "y2"])
+    x = Input(shape=(3,), name="x")
+    y1 = layers.Dense(1, name="y1", activation="sigmoid")(x)
+    y2 = layers.Dense(1, name="y2", activation="sigmoid")(x)
+    model = Model(x, Y(y1, y2))
+    return model, Y
+
+
+def _get_model_multi_outputs_struct_dict():
+    x = Input(shape=(3,), name="x")
+    y1 = layers.Dense(1, name="y1", activation="sigmoid")(x)
+    y2 = layers.Dense(1, name="y2", activation="sigmoid")(x)
+    model = Model(x, {"a": y1, "b": y2})
+    return model
+
+
+def _get_model_multi_outputs_struct():
+    x = Input(shape=(3,), name="x")
+    y1 = layers.Dense(1, name="y1", activation="sigmoid")(x)
+    y2 = layers.Dense(1, name="y2", activation="sigmoid")(x)
+    y3 = layers.Dense(1, name="y3", activation="sigmoid")(x)
+    model = Model(
+        x,
+        {
+            "a": (y1, y2),
+            "b": {"b1": y1, "b2": y2},
+            "c": {"c1": (y1, y2), "c2": y2},
+            "d": y3,
+        },
+    )
+    return model
+
+
 def _get_model_multi_outputs_dict_with_single_tensor():
     x = Input(shape=(3,), name="input_a")
     output = layers.Dense(1, name="output_a")(x)
@@ -121,6 +166,7 @@ def _get_variable_value_by_path(variables, path):

 @pytest.mark.requires_trainable_backend
 class ModelTest(testing.TestCase):
+
     def test_functional_rerouting(self):
         model = _get_model()
         self.assertIsInstance(model, Functional)
@@ -620,8 +666,7 @@ def test_functional_list_outputs_dict_losses_invalid_keys(self):
         # Fit the model to make sure compile_metrics are built
         with self.assertRaisesRegex(
             KeyError,
-            "in the `loss` argument, but they can't be found in the "
-            "model's output",
+            "in the `loss` argument, can't be found in the model's output",
         ):
             model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)

@@ -638,8 +683,7 @@ def test_functional_list_outputs_dict_losses_no_output_names(self):
         # Fit the model to make sure compile_metrics are built
         with self.assertRaisesRegex(
             KeyError,
-            "in the `loss` argument, but they can't be found in the "
-            "model's output",
+            "in the `loss` argument, can't be found in the model's output",
        ):
             model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)

@@ -683,8 +727,7 @@ def test_functional_dict_outputs_dict_losses_invalid_keys(self):
         # Fit the model to make sure compile_metrics are built
         with self.assertRaisesRegex(
             KeyError,
-            "in the `loss` argument, but they can't be found in the "
-            "model's output",
+            "in the `loss` argument, can't be found in the model's output",
         ):
             model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)

@@ -725,13 +768,10 @@ def test_functional_list_outputs_invalid_nested_list_losses(self):
                 ["mean_squared_error", "binary_crossentropy"],
             ],
         )
-        # Fit the model to make sure compile_metrics are built
-        with self.assertRaisesRegex(
-            ValueError,
-            "when providing the `loss` argument as a list, "
-            "it should have as many entries as the model has outputs",
-        ):
-            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist_keys = sorted(hist.history.keys())
+        ref_keys = sorted(["loss", "output_a_loss", "output_b_loss"])
+        self.assertListEqual(hist_keys, ref_keys)

     @parameterized.named_parameters(
         ("int8", "int8"),
@@ -944,3 +984,237 @@ def test_layers_setter(self):
             AttributeError, "`Model.layers` attribute is reserved"
         ):
             model.layers = [layers.Dense(4)]
+
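+    # Builds a loss callable that asserts y_true/y_pred match the given
+    # output structure (same nesting and ranks), then sums per-leaf MAE.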
+    def get_struct_loss(self, structure):
+        def loss_fn(y_true, y_pred):
+            tree.assert_same_structure(structure, y_true, check_types=False)
+            tree.assert_same_structure(structure, y_pred, check_types=False)
+            tree.map_structure(
+                lambda spec, tensor: self.assertEqual(spec.ndim, tensor.ndim),
+                structure,
+                y_true,
+            )
+            tree.map_structure(
+                lambda spec, tensor: self.assertEqual(spec.ndim, tensor.ndim),
+                structure,
+                y_pred,
+            )
+            flat_y_pred, flat_y_true = tree.flatten(y_pred), tree.flatten(
+                y_true
+            )
+            diff = 0
+            for y_p, y_t in zip(flat_y_pred, flat_y_true):
+                diff += losses.mean_absolute_error(y_t, y_p)
+            return diff
+
+        return loss_fn
+
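+    # List/tuple outputs with structured losses: when the loss container type
+    # matches the output type, fit should raise a structure mismatch error;
+    # otherwise training runs and reports per-output and combined losses.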
+    @parameterized.product(
+        _type=[tuple, list], other_type=[list, tuple], weighted=[False, True]
+    )
+    def test_functional_struct_outputs_struct_losses(
+        self, _type, other_type, weighted
+    ):
+        model = _get_model_multi_outputs_struct_list_like(_type)
+        self.assertIsInstance(model, Functional)
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.rand(8, 1)
+        y = _type([y1, y2])
+        loss = other_type(
+            [
+                self.get_struct_loss(model.output),
+                _type(
+                    [
+                        self.get_struct_loss(model.output[0]),
+                        self.get_struct_loss(model.output[1]),
+                    ]
+                ),
+            ]
+        )
+        if weighted:
+            loss_weights = tree.map_structure(lambda _: np.random.rand(), loss)
+        else:
+            loss_weights = None
+
+        model.compile(
+            optimizer="sgd",
+            loss=loss,
+            loss_weights=loss_weights,
+        )
+
+        if _type is other_type:
+            with self.assertRaisesRegex(
+                ValueError, "don't have the same structure"
+            ):
+                model.fit(x, y, batch_size=2, epochs=1, verbose=0)
+        else:
+            # Check that the outputs keep their structure type.
+            outputs = model.predict(x)
+            self.assertIsInstance(outputs, _type)
+            # Fit the model to make sure compile_metrics are built
+            hist = model.fit(
+                x,
+                y,
+                batch_size=2,
+                epochs=1,
+                verbose=0,
+            )
+            hist_keys = sorted(hist.history.keys())
+            ref_keys = sorted(
+                [
+                    "loss",
+                    "y1_loss",
+                    "y2_loss",
+                    "y1_y2_loss",
+                ]
+            )
+            self.assertListEqual(hist_keys, ref_keys)
+
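+    # Dict outputs with a loss for the whole output dict plus per-key losses;
+    # history should expose "a_loss", "b_loss" and the combined "a_b_loss".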
+    @parameterized.named_parameters(("weighted", True), ("not_weighted", False))
+    def test_functional_struct_outputs_dict_struct_losses(self, weighted):
+        model = _get_model_multi_outputs_struct_dict()
+        self.assertIsInstance(model, Functional)
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.rand(8, 1)
+
+        y = {"a": y1, "b": y2}
+        loss = [
+            self.get_struct_loss(model.output),
+            {
+                "a": self.get_struct_loss(model.output["a"]),
+                "b": self.get_struct_loss(model.output["a"]),
+            },
+        ]
+        if weighted:
+            loss_weights = tree.map_structure(lambda _: np.random.rand(), loss)
+        else:
+            loss_weights = None
+
+        model.compile(
+            optimizer="sgd",
+            loss=loss,
+            loss_weights=loss_weights,
+        )
+        # Check dict outputs.
+        outputs = model.predict(x)
+        self.assertIsInstance(outputs, dict)
+
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(
+            x,
+            y,
+            batch_size=2,
+            epochs=1,
+            verbose=0,
+        )
+        hist_keys = sorted(hist.history.keys())
+        ref_keys = sorted(
+            [
+                "loss",
+                "a_loss",
+                "b_loss",
+                "a_b_loss",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
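+    # Namedtuple outputs: per-output losses can be passed as a namedtuple of
+    # callables, and predict returns a tuple.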
+    def test_functional_struct_outputs_namedtuple_struct_losses(self):
+        model, Y = _get_model_multi_outputs_struct_namedtuple()
+        self.assertIsInstance(model, Functional)
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.rand(8, 1)
+
+        y = Y(y1, y2)
+        model.compile(
+            optimizer="sgd",
+            loss=[
+                self.get_struct_loss(model.output),
+                Y(
+                    self.get_struct_loss(model.output.y1),
+                    self.get_struct_loss(model.output.y2),
+                ),
+            ],
+        )
+        # Check tuple outputs.
+        outputs = model.predict(x)
+        self.assertIsInstance(outputs, tuple)
+
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(
+            x,
+            y,
+            batch_size=2,
+            epochs=1,
+            verbose=0,
+        )
+        hist_keys = sorted(hist.history.keys())
+        ref_keys = sorted(
+            [
+                "loss",
+                "y1_loss",
+                "y2_loss",
+                "y1_y2_loss",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
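+    # Deeply nested dict/tuple outputs: losses may be attached at any level,
+    # and history keys use "/"-joined paths such as "a/y2_loss".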
+    def test_functional_deeply_nested_outputs_struct_losses(self):
+        model = _get_model_multi_outputs_struct()
+        self.assertIsInstance(model, Functional)
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.rand(8, 1)
+        y3 = np.random.rand(8, 1)
+        y = {
+            "a": (y1, y2),
+            "b": {"b1": y1, "b2": y2},
+            "c": {"c1": (y1, y2), "c2": y2},
+            "d": y3,
+        }
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "a": [
+                    self.get_struct_loss(model.output["a"]),
+                    (None, self.get_struct_loss(model.output["a"][1])),
+                ],
+                "b": [
+                    self.get_struct_loss(model.output["b"]),
+                    {"b1": self.get_struct_loss(model.output["b"]["b1"])},
+                ],
+                "c": [
+                    self.get_struct_loss(model.output["c"]),
+                    {"c1": self.get_struct_loss(model.output["c"]["c1"])},
+                ],
+                "d": self.get_struct_loss(model.output["d"]),
+            },
+        )
+        # Check dict outputs.
+        outputs = model.predict(x)
+        self.assertIsInstance(outputs, dict)
+
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(
+            x,
+            y,
+            batch_size=2,
+            epochs=1,
+            verbose=0,
+        )
+        hist_keys = sorted(hist.history.keys())
+        ref_keys = sorted(
+            [
+                "a/y2_loss",
+                "a_loss",
+                "b/b1_loss",
+                "b_loss",
+                "c/c1_loss",
+                "c_loss",
+                "d_loss",
+                "loss",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)