Skip to content

Commit a11ef39

Browse files
authored
Proposed fix for issue keras-team#21519: Reshape layer does not handle -1 shape… (keras-team#21568)
* Proposed fix for issue keras-team#21519: Reshape layer does not handle -1 shape infor dynamically. * Removed unused part of the code. The code in the fix should deal exclusively with the case that is not properly handled by ops.reshape. This is when the batch dimension does not have a static shape (shape == None) but all other dimensions have static shape. In that case we can determine the static shape of the dimension containing -1 form all dimensions that are not the batch dimension. * applied changes according to pre-commit hook * Added reshape_test test case that fails with original implementation and succeeds with fix. * Added asserts for expected result. * Fixed test name and added a further test with custom Model and Conv1D layer. * applied changes according to pre-commit hook * Fixed test to use a custom model and not a custom layer as indicated in the test name. * Implemented suggested changes: - Use original implementation from build method to avoid repeating the implementation in compute_reshape_output_shape. - Set the self.built attribute to True in __init__ because no further build is required. Additional change: - Mention the optional use of -1 for a single dimension of the target_shape. * Implemented suggested changes: - removed explicit call to layer.build from the tests, - changed the new test to use Model.compile and Model.fit to cover the corresponding API in the tests. * Remove unused variable. * Fixed line lengths in doc string. * Marked test which uses fit method to require a trainable backend. * Docs: - Adapted according to review. - Additionally, replaced "length" with "size" when referring to total number of elements. * Fixed line length.
1 parent b9ff57a commit a11ef39

File tree

2 files changed

+48
-17
lines changed

2 files changed

+48
-17
lines changed

keras/src/layers/reshaping/reshape.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,12 @@ class Reshape(Layer):
1111
1212
Args:
1313
target_shape: Target shape. Tuple of integers, does not include the
14-
samples dimension (batch size).
14+
samples dimension (batch size). One element of the `target_shape`
15+
can be -1 in which case the missing value is inferred from the
16+
size of the array and remaining dimensions.
1517
1618
Input shape:
17-
Arbitrary, although all dimensions in the input shape must be
18-
known/fixed. Use the keyword argument `input_shape` (tuple of integers,
19-
does not include the samples/batch size axis) when using this layer as
20-
the first layer in a model.
19+
Arbitrary, but required to be compatible with `target_shape`.
2120
2221
Output shape:
2322
`(batch_size, *target_shape)`
@@ -29,15 +28,23 @@ class Reshape(Layer):
2928
>>> y.shape
3029
(None, 3, 4)
3130
32-
>>> # also supports shape inference using `-1` as dimension
31+
>>> # another example with shape inference using `-1` as dimension
3332
>>> y = keras.layers.Reshape((-1, 2, 2))(x)
3433
>>> y.shape
3534
(None, 3, 2, 2)
3635
"""
3736

3837
def __init__(self, target_shape, **kwargs):
3938
super().__init__(**kwargs)
40-
self.target_shape = tuple(target_shape)
39+
target_shape = tuple(target_shape)
40+
# test validity of target_shape
41+
if target_shape.count(-1) > 1:
42+
raise ValueError(
43+
"The `target_shape` argument must not contain more than one "
44+
f"`-1` value. Received: target_shape={target_shape}"
45+
)
46+
self.target_shape = target_shape
47+
self.built = True
4148

4249
def compute_output_shape(self, input_shape):
4350
return (
@@ -53,17 +60,17 @@ def compute_output_spec(self, inputs):
5360
shape=output_shape, dtype=inputs.dtype, sparse=inputs.sparse
5461
)
5562

56-
def build(self, input_shape):
57-
sample_output_shape = operation_utils.compute_reshape_output_shape(
58-
input_shape[1:], self.target_shape, "target_shape"
63+
def call(self, inputs):
64+
potentially_resolved_target_shape = (
65+
operation_utils.compute_reshape_output_shape(
66+
tuple(inputs.shape)[1:], self.target_shape, "target_shape"
67+
)
5968
)
60-
self._resolved_target_shape = tuple(
61-
-1 if d is None else d for d in sample_output_shape
69+
potentially_resolved_target_shape = tuple(
70+
-1 if d is None else d for d in potentially_resolved_target_shape
6271
)
63-
64-
def call(self, inputs):
6572
return ops.reshape(
66-
inputs, (ops.shape(inputs)[0],) + self._resolved_target_shape
73+
inputs, (ops.shape(inputs)[0],) + potentially_resolved_target_shape
6774
)
6875

6976
def get_config(self):

keras/src/layers/reshaping/reshape_test.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import pytest
22
from absl.testing import parameterized
33

4+
from keras.src import Sequential
45
from keras.src import backend
56
from keras.src import layers
7+
from keras.src import ops
68
from keras.src import testing
79
from keras.src.backend.common.keras_tensor import KerasTensor
810

@@ -96,14 +98,19 @@ def test_reshape_with_dynamic_batch_size(self):
9698
def test_reshape_with_dynamic_batch_size_and_minus_one(self):
9799
input = KerasTensor((None, 6, 4))
98100
layer = layers.Reshape((-1, 8))
99-
layer.build(input.shape)
100101
reshaped = backend.compute_output_spec(layer.__call__, input)
101102
self.assertEqual(reshaped.shape, (None, 3, 8))
102103

104+
def test_reshape_layer_with_varying_input_size_and_minus_one(self):
105+
layer = layers.Reshape((-1, 8))
106+
res = layer(ops.ones((1, 6, 4), dtype="float32"))
107+
self.assertEqual(res.shape, (1, 3, 8))
108+
res = layer(ops.ones((1, 10, 4), dtype="float32"))
109+
self.assertEqual(res.shape, (1, 5, 8))
110+
103111
def test_reshape_with_dynamic_dim_and_minus_one(self):
104112
input = KerasTensor((4, 6, None, 3))
105113
layer = layers.Reshape((-1, 3))
106-
layer.build(input.shape)
107114
reshaped = backend.compute_output_spec(layer.__call__, input)
108115
self.assertEqual(reshaped.shape, (4, None, 3))
109116

@@ -112,3 +119,20 @@ def test_reshape_sets_static_shape(self):
112119
reshaped = layers.Reshape((3, 5))(input_layer)
113120
# Also make sure the batch dim is not lost after reshape.
114121
self.assertEqual(reshaped.shape, (2, 3, 5))
122+
123+
@pytest.mark.requires_trainable_backend
124+
def test_reshape_model_fit_with_varying_input_size_and_minus_one(self):
125+
def generator():
126+
yield (
127+
ops.ones((1, 12, 2), dtype="float32"),
128+
ops.zeros((1, 3, 8), dtype="float32"),
129+
)
130+
yield (
131+
ops.ones((1, 20, 2), dtype="float32"),
132+
ops.zeros((1, 5, 8), dtype="float32"),
133+
)
134+
135+
layer = layers.Reshape((-1, 8))
136+
model = Sequential([layer])
137+
model.compile(loss="mean_squared_error")
138+
model.fit(generator(), steps_per_epoch=2, epochs=1)

0 commit comments

Comments
 (0)