Skip to content

Commit 55487f3

Browse files
javidcffchollet
authored andcommitted
Added dtype parameter to zeros_like and ones_like (keras-team#5062)
* Fixed checking input masks in Layer.compute_mask * Added dtype parameter to zeros_like and ones_like * Fix existing docstring for ones_like and zeros_like
1 parent 1c6db08 commit 55487f3

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

keras/backend/tensorflow_backend.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -534,15 +534,17 @@ def eye(size, dtype=None, name=None):
534534
return variable(np.eye(size), dtype, name)
535535

536536

537-
def zeros_like(x, name=None):
537+
def zeros_like(x, dtype=None, name=None):
538538
"""Instantiates an all-zeros Keras variable
539539
of the same shape as another Keras variable or tensor and returns it.
540540
541541
# Arguments
542542
x: Keras variable or Keras tensor.
543+
dtype: String, dtype of returned Keras variable.
544+
None uses the dtype of x.
543545
544546
# Returns
545-
A Keras variable, filled with `0.0`.
547+
A Keras variable with the shape of x filled with zeros.
546548
547549
# Example
548550
```python
@@ -554,18 +556,20 @@ def zeros_like(x, name=None):
554556
[ 0., 0., 0.]], dtype=float32)
555557
```
556558
"""
557-
return tf.zeros_like(x, name=name)
559+
return tf.zeros_like(x, dtype=dtype, name=name)
558560

559561

560-
def ones_like(x, name=None):
562+
def ones_like(x, dtype=None, name=None):
561563
"""Instantiates an all-ones Keras variable
562564
of the same shape as another Keras variable or tensor and returns it.
563565
564566
# Arguments
565567
x: Keras variable or tensor.
568+
dtype: String, dtype of returned Keras variable.
569+
None uses the dtype of x.
566570
567571
# Returns
568-
A Keras variable, filled with `1.0`.
572+
A Keras variable with the shape of x filled with ones.
569573
570574
# Example
571575
```python
@@ -577,7 +581,7 @@ def ones_like(x, name=None):
577581
[ 1., 1., 1.]], dtype=float32)
578582
```
579583
"""
580-
return tf.ones_like(x, name=name)
584+
return tf.ones_like(x, dtype=dtype, name=name)
581585

582586

583587
def random_uniform_variable(shape, low, high, dtype=None,

keras/backend/theano_backend.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,12 @@ def eye(size, dtype=None, name=None):
176176
return variable(np.eye(size), dtype, name)
177177

178178

179-
def ones_like(x, name=None):
180-
return T.ones_like(x)
179+
def ones_like(x, dtype=None, name=None):
180+
return T.ones_like(x, dtype=dtype)
181181

182182

183-
def zeros_like(x, name=None):
184-
return T.zeros_like(x)
183+
def zeros_like(x, dtype=None, name=None):
184+
return T.zeros_like(x, dtype=dtype)
185185

186186

187187
def random_uniform_variable(shape, low, high, dtype=None, name=None):

0 commit comments

Comments
 (0)