
Commit 4ff65a5

Enable Structured Losses (keras-team#20358)
* Enable structured losses
* fix torch scalar broadcast
* format
* path-like -> struct-like
* revert removal of rank check
* `type` -> `_type`, `if` -> `elif`
* `anti_type` -> `other_type`
* removed naming truncation
1 parent 2cddf9e commit 4ff65a5
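
In practical terms, this change lets `Model.compile()` accept a `loss` argument whose nesting mirrors an arbitrarily structured model output (tuples, lists, dicts, namedtuples), instead of only flat lists or dicts keyed by output name. A minimal sketch of the kind of usage the new tests exercise, assuming string loss identifiers are accepted at the leaves just as they are for flat outputs (the layer names and data below are illustrative, not from the commit):

    import numpy as np
    from keras import Input, Model, layers

    # Functional model whose output is a nested structure: a dict holding a tuple.
    inputs = Input(shape=(3,))
    y1 = layers.Dense(1, activation="sigmoid", name="y1")(inputs)
    y2 = layers.Dense(1, activation="sigmoid", name="y2")(inputs)
    y3 = layers.Dense(1, activation="sigmoid", name="y3")(inputs)
    model = Model(inputs, {"a": (y1, y2), "d": y3})

    # The `loss` argument mirrors the output structure; each leaf gets its own loss.
    model.compile(
        optimizer="sgd",
        loss={
            "a": ("mean_squared_error", "mean_absolute_error"),
            "d": "binary_crossentropy",
        },
    )

    x = np.random.rand(8, 3)
    targets = {
        "a": (np.random.rand(8, 1), np.random.rand(8, 1)),
        "d": np.random.rand(8, 1),
    }
    history = model.fit(x, targets, batch_size=2, epochs=1, verbose=0)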

5 files changed: +676, -194 lines


keras/src/losses/losses.py
Lines changed: 6 additions & 1 deletion

@@ -2,6 +2,7 @@
 
 from keras.src import backend
 from keras.src import ops
+from keras.src import tree
 from keras.src.api_export import keras_export
 from keras.src.losses.loss import Loss
 from keras.src.losses.loss import squeeze_or_expand_to_same_rank
@@ -23,7 +24,11 @@ def __init__(
         self._fn_kwargs = kwargs
 
     def call(self, y_true, y_pred):
-        y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
+        y_true_y_pred = tree.map_structure(
+            squeeze_or_expand_to_same_rank, y_true, y_pred
+        )
+        y_true = tree.map_structure_up_to(y_true, lambda x: x[0], y_true_y_pred)
+        y_pred = tree.map_structure_up_to(y_pred, lambda x: x[1], y_true_y_pred)
         return self.fn(y_true, y_pred, **self._fn_kwargs)
 
     def get_config(self):
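
For reference, the rewritten `call` relies on a pair-then-split pattern: `tree.map_structure` runs `squeeze_or_expand_to_same_rank` on each corresponding `(y_true, y_pred)` leaf pair and yields a structure of result tuples, which the two `map_structure_up_to` calls then unzip back into separate `y_true` and `y_pred` structures with the original nesting. A standalone sketch of that pattern (the nested dict and the plain pairing lambda are illustrative stand-ins for the real rank-matching function):

    import numpy as np
    from keras.src import tree

    y_true = {"a": np.zeros((4, 1)), "b": (np.zeros((4, 1)), np.zeros((4, 1)))}
    y_pred = {"a": np.ones((4, 1)), "b": (np.ones((4, 1)), np.ones((4, 1)))}

    # Each leaf of `paired` becomes a (true_leaf, pred_leaf) tuple.
    paired = tree.map_structure(lambda t, p: (t, p), y_true, y_pred)

    # Split the paired structure back apart, keeping the original nesting.
    y_true_matched = tree.map_structure_up_to(y_true, lambda x: x[0], paired)
    y_pred_matched = tree.map_structure_up_to(y_pred, lambda x: x[1], paired)

    # Both results have the same structure as the inputs they came from.
    tree.assert_same_structure(y_true, y_true_matched, check_types=False)
    tree.assert_same_structure(y_pred, y_pred_matched, check_types=False)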

keras/src/models/model_test.py
Lines changed: 287 additions & 13 deletions

@@ -1,12 +1,15 @@
 import pickle
+from collections import namedtuple
 
 import numpy as np
 import pytest
 from absl.testing import parameterized
 
 from keras.src import backend
 from keras.src import layers
+from keras.src import losses
 from keras.src import testing
+from keras.src import tree
 from keras.src.layers.core.input_layer import Input
 from keras.src.models.functional import Functional
 from keras.src.models.model import Model
@@ -68,6 +71,48 @@ def _get_model_multi_outputs_dict():
     return model
 
 
+def _get_model_multi_outputs_struct_list_like(_type):
+    x = Input(shape=(3,), name="x")
+    y1 = layers.Dense(1, name="y1", activation="sigmoid")(x)
+    y2 = layers.Dense(1, name="y2", activation="sigmoid")(x)
+    model = Model(x, _type([y1, y2]))
+    return model
+
+
+def _get_model_multi_outputs_struct_namedtuple():
+    Y = namedtuple("Y", ["y1", "y2"])
+    x = Input(shape=(3,), name="x")
+    y1 = layers.Dense(1, name="y1", activation="sigmoid")(x)
+    y2 = layers.Dense(1, name="y2", activation="sigmoid")(x)
+    model = Model(x, Y(y1, y2))
+    return model, Y
+
+
+def _get_model_multi_outputs_struct_dict():
+    x = Input(shape=(3,), name="x")
+    y1 = layers.Dense(1, name="y1", activation="sigmoid")(x)
+    y2 = layers.Dense(1, name="y2", activation="sigmoid")(x)
+    model = Model(x, {"a": y1, "b": y2})
+    return model
+
+
+def _get_model_multi_outputs_struct():
+    x = Input(shape=(3,), name="x")
+    y1 = layers.Dense(1, name="y1", activation="sigmoid")(x)
+    y2 = layers.Dense(1, name="y2", activation="sigmoid")(x)
+    y3 = layers.Dense(1, name="y3", activation="sigmoid")(x)
+    model = Model(
+        x,
+        {
+            "a": (y1, y2),
+            "b": {"b1": y1, "b2": y2},
+            "c": {"c1": (y1, y2), "c2": y2},
+            "d": y3,
+        },
+    )
+    return model
+
+
 def _get_model_multi_outputs_dict_with_single_tensor():
     x = Input(shape=(3,), name="input_a")
     output = layers.Dense(1, name="output_a")(x)
@@ -121,6 +166,7 @@ def _get_variable_value_by_path(variables, path):
 
 @pytest.mark.requires_trainable_backend
 class ModelTest(testing.TestCase):
+
     def test_functional_rerouting(self):
         model = _get_model()
         self.assertIsInstance(model, Functional)
@@ -620,8 +666,7 @@ def test_functional_list_outputs_dict_losses_invalid_keys(self):
         # Fit the model to make sure compile_metrics are built
         with self.assertRaisesRegex(
             KeyError,
-            "in the `loss` argument, but they can't be found in the "
-            "model's output",
+            "in the `loss` argument, can't be found in the model's output",
         ):
             model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
 
@@ -638,8 +683,7 @@ def test_functional_list_outputs_dict_losses_no_output_names(self):
         # Fit the model to make sure compile_metrics are built
         with self.assertRaisesRegex(
             KeyError,
-            "in the `loss` argument, but they can't be found in the "
-            "model's output",
+            "in the `loss` argument, can't be found in the model's output",
         ):
             model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
 
@@ -683,8 +727,7 @@ def test_functional_dict_outputs_dict_losses_invalid_keys(self):
         # Fit the model to make sure compile_metrics are built
         with self.assertRaisesRegex(
             KeyError,
-            "in the `loss` argument, but they can't be found in the "
-            "model's output",
+            "in the `loss` argument, can't be found in the model's output",
         ):
             model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
 
@@ -725,13 +768,10 @@ def test_functional_list_outputs_invalid_nested_list_losses(self):
                 ["mean_squared_error", "binary_crossentropy"],
             ],
         )
-        # Fit the model to make sure compile_metrics are built
-        with self.assertRaisesRegex(
-            ValueError,
-            "when providing the `loss` argument as a list, "
-            "it should have as many entries as the model has outputs",
-        ):
-            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist_keys = sorted(hist.history.keys())
+        ref_keys = sorted(["loss", "output_a_loss", "output_b_loss"])
+        self.assertListEqual(hist_keys, ref_keys)
 
     @parameterized.named_parameters(
         ("int8", "int8"),
@@ -944,3 +984,237 @@ def test_layers_setter(self):
             AttributeError, "`Model.layers` attribute is reserved"
         ):
             model.layers = [layers.Dense(4)]
+
+    def get_struct_loss(self, structure):
+        def loss_fn(y_true, y_pred):
+            tree.assert_same_structure(structure, y_true, check_types=False)
+            tree.assert_same_structure(structure, y_pred, check_types=False)
+            tree.map_structure(
+                lambda spec, tensor: self.assertEqual(spec.ndim, tensor.ndim),
+                structure,
+                y_true,
+            )
+            tree.map_structure(
+                lambda spec, tensor: self.assertEqual(spec.ndim, tensor.ndim),
+                structure,
+                y_pred,
+            )
+            flat_y_pred, flat_y_true = tree.flatten(y_pred), tree.flatten(
+                y_true
+            )
+            diff = 0
+            for y_p, y_t in zip(flat_y_pred, flat_y_true):
+                diff += losses.mean_absolute_error(y_t, y_p)
+            return diff
+
+        return loss_fn
+
+    @parameterized.product(
+        _type=[tuple, list], other_type=[list, tuple], weighted=[False, True]
+    )
+    def test_functional_struct_outputs_struct_losses(
+        self, _type, other_type, weighted
+    ):
+        model = _get_model_multi_outputs_struct_list_like(_type)
+        self.assertIsInstance(model, Functional)
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.rand(8, 1)
+        y = _type([y1, y2])
+        loss = other_type(
+            [
+                self.get_struct_loss(model.output),
+                _type(
+                    [
+                        self.get_struct_loss(model.output[0]),
+                        self.get_struct_loss(model.output[1]),
+                    ]
+                ),
+            ]
+        )
+        if weighted:
+            loss_weights = tree.map_structure(lambda _: np.random.rand(), loss)
+        else:
+            loss_weights = None
+
+        model.compile(
+            optimizer="sgd",
+            loss=loss,
+            loss_weights=loss_weights,
+        )
+
+        if _type is other_type:
+            with self.assertRaisesRegex(
+                ValueError, "don't have the same structure"
+            ):
+                model.fit(x, y, batch_size=2, epochs=1, verbose=0)
+        else:
+            # Check dict outputs.
+            outputs = model.predict(x)
+            self.assertIsInstance(outputs, _type)
+            # Fit the model to make sure compile_metrics are built
+            hist = model.fit(
+                x,
+                y,
+                batch_size=2,
+                epochs=1,
+                verbose=0,
+            )
+            hist_keys = sorted(hist.history.keys())
+            ref_keys = sorted(
+                [
+                    "loss",
+                    "y1_loss",
+                    "y2_loss",
+                    "y1_y2_loss",
+                ]
+            )
+            self.assertListEqual(hist_keys, ref_keys)
+
+    @parameterized.named_parameters(("weighted", True), ("not_weighted", False))
+    def test_functional_struct_outputs_dict_struct_losses(self, weighted):
+        model = _get_model_multi_outputs_struct_dict()
+        self.assertIsInstance(model, Functional)
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.rand(8, 1)
+
+        y = {"a": y1, "b": y2}
+        loss = [
+            self.get_struct_loss(model.output),
+            {
+                "a": self.get_struct_loss(model.output["a"]),
+                "b": self.get_struct_loss(model.output["a"]),
+            },
+        ]
+        if weighted:
+            loss_weights = tree.map_structure(lambda _: np.random.rand(), loss)
+        else:
+            loss_weights = None
+
+        model.compile(
+            optimizer="sgd",
+            loss=loss,
+            loss_weights=loss_weights,
+        )
+        # Check dict outputs.
+        outputs = model.predict(x)
+        self.assertIsInstance(outputs, dict)
+
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(
+            x,
+            y,
+            batch_size=2,
+            epochs=1,
+            verbose=0,
+        )
+        hist_keys = sorted(hist.history.keys())
+        ref_keys = sorted(
+            [
+                "loss",
+                "a_loss",
+                "b_loss",
+                "a_b_loss",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
+    def test_functional_struct_outputs_namedtuple_struct_losses(self):
+        model, Y = _get_model_multi_outputs_struct_namedtuple()
+        self.assertIsInstance(model, Functional)
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.rand(8, 1)
+
+        y = Y(y1, y2)
+        model.compile(
+            optimizer="sgd",
+            loss=[
+                self.get_struct_loss(model.output),
+                Y(
+                    self.get_struct_loss(model.output.y1),
+                    self.get_struct_loss(model.output.y2),
+                ),
+            ],
+        )
+        # Check dict outputs.
+        outputs = model.predict(x)
+        self.assertIsInstance(outputs, tuple)
+
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(
+            x,
+            y,
+            batch_size=2,
+            epochs=1,
+            verbose=0,
+        )
+        hist_keys = sorted(hist.history.keys())
+        ref_keys = sorted(
+            [
+                "loss",
+                "y1_loss",
+                "y2_loss",
+                "y1_y2_loss",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
+    def test_functional_deeply_nested_outputs_struct_losses(self):
+        model = _get_model_multi_outputs_struct()
+        self.assertIsInstance(model, Functional)
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.rand(8, 1)
+        y3 = np.random.rand(8, 1)
+        y = {
+            "a": (y1, y2),
+            "b": {"b1": y1, "b2": y2},
+            "c": {"c1": (y1, y2), "c2": y2},
+            "d": y3,
+        }
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "a": [
+                    self.get_struct_loss(model.output["a"]),
+                    (None, self.get_struct_loss(model.output["a"][1])),
+                ],
+                "b": [
+                    self.get_struct_loss(model.output["b"]),
+                    {"b1": self.get_struct_loss(model.output["b"]["b1"])},
+                ],
+                "c": [
+                    self.get_struct_loss(model.output["c"]),
+                    {"c1": self.get_struct_loss(model.output["c"]["c1"])},
+                ],
+                "d": self.get_struct_loss(model.output["d"]),
+            },
+        )
+        # Check dict outputs.
+        outputs = model.predict(x)
+        self.assertIsInstance(outputs, dict)
+
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(
+            x,
+            y,
+            batch_size=2,
+            epochs=1,
+            verbose=0,
+        )
+        hist_keys = sorted(hist.history.keys())
+        ref_keys = sorted(
+            [
+                "a/y2_loss",
+                "a_loss",
+                "b/b1_loss",
+                "b_loss",
+                "c/c1_loss",
+                "c_loss",
+                "d_loss",
+                "loss",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
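
One detail worth noting from the weighted test variants above: `loss_weights` may mirror the nested `loss` structure leaf for leaf, so a weight can be attached both to per-output losses and to a loss that covers a whole sub-structure. A small sketch of building such a weight structure the same way the tests do (the loss identifiers and random weights are illustrative only):

    import numpy as np
    from keras.src import tree

    # A nested loss specification: one loss per leaf of a list-of-(tuple) structure.
    loss = [
        "mean_absolute_error",
        ("mean_squared_error", "mean_squared_error"),
    ]

    # Build loss_weights with exactly the same nesting as `loss`.
    loss_weights = tree.map_structure(lambda _: np.random.rand(), loss)
    print(loss_weights)  # e.g. [0.42, (0.17, 0.93)]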
