16
16
"""Dataset feature for nested datasets."""
17
17
from __future__ import annotations
18
18
19
+ import dataclasses
19
20
import functools
20
- from typing import Any , Dict , Iterator , Union
21
+ from typing import Any , Callable , Dict , Iterator , Union
21
22
23
+ from tensorflow_datasets .core .data_sources import python
22
24
from tensorflow_datasets .core .features import feature as feature_lib
23
25
from tensorflow_datasets .core .features import sequence_feature
26
+ from tensorflow_datasets .core .features import tensor_feature
27
+ from tensorflow_datasets .core .features import top_level_feature
24
28
from tensorflow_datasets .core .utils import py_utils
25
29
from tensorflow_datasets .core .utils import tree_utils
26
30
from tensorflow_datasets .core .utils import type_utils
27
31
from tensorflow_datasets .core .utils .lazy_imports_utils import tensorflow as tf
28
32
29
33
34
+ @dataclasses .dataclass (frozen = True )
35
+ class _getitem : # pylint: disable=invalid-name
36
+ """A pickable version of getitem that can be fed to Beam pipelines."""
37
+
38
+ decode_fn : Callable [[Any ], Any ]
39
+ nest : Callable [[Any ], Any ]
40
+ flat_example : list [Any ]
41
+
42
+ def __call__ (self , i ):
43
+ return self .decode_fn (self .nest ([v [i ] for v in self .flat_example ]))
44
+
45
+
30
46
class Dataset (sequence_feature .Sequence ):
31
47
"""A Dataset feature encodes a nested dataset.
32
48
@@ -35,6 +51,12 @@ class Dataset(sequence_feature.Sequence):
35
51
top-level `tf.data.Dataset` returned by `tfds.load`. At generation time, an
36
52
iterable over the dataset elements is given.
37
53
54
+ If you use tfds.data_source and the NumPy path, `Dataset` will return
55
+ a [data
56
+ source](https://www.tensorflow.org/datasets/api_docs/python/tfds/data_source).
57
+ The advantage of having a data source is that decoding will be lazily executed
58
+ when you access each example in the dataset.
59
+
38
60
This is an experimental feature. Currently, only one level of nesting is
39
61
supported and TF1 graph is not supported either.
40
62
@@ -146,6 +168,54 @@ def decode_example(self, serialized_example, decoders=None):
146
168
)
147
169
return ds
148
170
171
  def decode_example_np(
      self, serialized_example, decoders=None
  ) -> python.PythonDataSource:
    """Decodes a serialized nested dataset into a lazy NumPy data source.

    See base class for details.

    The serialized example is flattened into one column per inner feature.
    Each column is split into `num_slices` records ("slices"); decoding of a
    slice is deferred until the returned data source is indexed.

    Args:
      serialized_example: The serialized form of one nested dataset, matching
        the structure of `self.feature` (one value per flat feature).
      decoders: Optional nested decoders, forwarded to
        `self.feature.decode_example_np` only when the inner feature is a
        `TopLevelFeature` (only top-level features accept decoders).

    Returns:
      A `python.PythonDataSource` of length `num_slices` whose `__getitem__`
      decodes a single slice on access.

    Raises:
      ValueError: If a tensor column cannot be reshaped to
        `(-1,) + feature.shape`, if the columns disagree on the number of
        slices, or if the feature has no inner features at all.
    """
    # Private structure helpers of the inner feature; used to flatten the
    # example into columns and later rebuild one slice's nested structure.
    flatten = self.feature._flatten  # pylint: disable=protected-access
    nest = self.feature._nest  # pylint: disable=protected-access
    flat_example = flatten(serialized_example)
    flat_features = flatten(self.feature)
    num_slices: int | None = None

    # First discover the number of slices in the Dataset. Notably, it's possible
    # that tensors have to be reshaped. We call slice a record in the Dataset.
    # We don't use `example` to avoid confusion with the `serialized_example`.
    for i, feature in enumerate(flat_features):
      if isinstance(feature, tensor_feature.Tensor) and feature.shape:
        # Tensors with a static shape arrive as one flat buffer; recover the
        # per-slice leading dimension with a (-1, *shape) reshape.
        try:
          flat_example[i] = flat_example[i].reshape((-1,) + feature.shape)
        except ValueError as e:
          raise ValueError(
              "The length of all elements of one slice should be the same."
          ) from e
        feature_num_slices = flat_example[i].shape[0]
      else:
        # Non-tensor (or unshaped) columns: one entry per slice.
        feature_num_slices = len(flat_example[i])
      # All columns must agree on the slice count.
      if num_slices is None:
        num_slices = feature_num_slices
      else:
        if feature_num_slices != num_slices:
          raise ValueError(
              "The length of elements of all slices should be the same. Got"
              f" {num_slices} and {feature_num_slices}"
          )
    # num_slices is still None only when the loop never ran, i.e. the inner
    # feature flattened to zero columns.
    if num_slices is None:
      raise ValueError("no feature was found.")

    # Then, we apply the decode function on each slice.
    if isinstance(self.feature, top_level_feature.TopLevelFeature):
      # Only top-level features accept decoders.
      decode_fn = functools.partial(
          self.feature.decode_example_np, decoders=decoders
      )
    else:
      decode_fn = self.feature.decode_example_np

    # _getitem is a picklable callable so the data source stays Beam-friendly;
    # decoding happens lazily, per accessed slice.
    return python.PythonDataSource(
        length=num_slices, getitem=_getitem(decode_fn, nest, flat_example)
    )
218
+
149
219
def _flatten (self , x ):
150
220
"""See base class for details."""
151
221
return [x ]
0 commit comments