Skip to content

Commit ccd89db

Browse files
marcenacpThe TensorFlow Datasets Authors
authored andcommitted
Implement decode_example_np for tfds.core.features.Dataset.
PiperOrigin-RevId: 569395247
1 parent 4517d3e commit ccd89db

File tree

6 files changed

+335
-4
lines changed

6 files changed

+335
-4
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# coding=utf-8
2+
# Copyright 2023 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Python DataSource base class."""
17+
18+
from __future__ import annotations
19+
20+
from collections.abc import Sequence
21+
import dataclasses
22+
from typing import Any, Callable
23+
24+
25+
@dataclasses.dataclass
26+
class PythonDataSource(Sequence):
27+
"""Python data source backed by Python objects: length and __getitem__."""
28+
29+
length: int
30+
# If you have pickling issues for this function, define it in the upper scope
31+
# or get inspiration from _getitem in
32+
# tensorflow_datasets/core/features/dataset_feature.py.
33+
getitem: Callable[[int], Any]
34+
35+
def __len__(self) -> int:
36+
return self.length
37+
38+
def __iter__(self):
39+
for i in range(self.length):
40+
yield self[i]
41+
42+
def __getitem__(self, i: int) -> Any:
43+
return self.getitem(i)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# coding=utf-8
2+
# Copyright 2023 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for the PythonDataSource."""
17+
18+
import pickle
19+
20+
from tensorflow_datasets.core.data_sources import python
21+
22+
23+
def getitem(i):
24+
return i
25+
26+
27+
def test_create_a_python_data_source():
28+
source = python.PythonDataSource(length=2, getitem=getitem)
29+
assert len(source) == 2
30+
assert source[0] == 0
31+
assert source[1] == 1
32+
assert source[2] == 2
33+
34+
35+
def test_iterate_on_a_python_data_source():
36+
source = python.PythonDataSource(length=42, getitem=getitem)
37+
i = 0
38+
for i, j in enumerate(iter(source)):
39+
assert i == j
40+
assert i == 41
41+
42+
43+
def test_python_data_source_is_pickable():
44+
def _getitem(i):
45+
return i
46+
47+
source = python.PythonDataSource(length=42, getitem=_getitem)
48+
source = pickle.loads(pickle.dumps(source))
49+
assert source[0] == 0

tensorflow_datasets/core/decode/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,18 @@ def decode_example(self, serialized_example):
8484
"""
8585
raise NotImplementedError('Abstract class')
8686

87+
def decode_example_np(self, serialized_example):
88+
"""Decode the example feature field for NumPy (eg: image).
89+
90+
Args:
91+
serialized_example: `np.array` as decoded, the dtype/shape should be
92+
identical to `feature.get_serialized_info()`.
93+
94+
Returns:
95+
example: Decoded example. Defaults to `decode_example`.
96+
"""
97+
return self.decode_example(serialized_example)
98+
8799
def decode_batch_example(self, serialized_example):
88100
"""See `FeatureConnector.decode_batch_example` for details."""
89101
return tf.map_fn(

tensorflow_datasets/core/features/dataset_feature.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,33 @@
1616
"""Dataset feature for nested datasets."""
1717
from __future__ import annotations
1818

19+
import dataclasses
1920
import functools
20-
from typing import Any, Dict, Iterator, Union
21+
from typing import Any, Callable, Dict, Iterator, Union
2122

23+
from tensorflow_datasets.core.data_sources import python
2224
from tensorflow_datasets.core.features import feature as feature_lib
2325
from tensorflow_datasets.core.features import sequence_feature
26+
from tensorflow_datasets.core.features import tensor_feature
27+
from tensorflow_datasets.core.features import top_level_feature
2428
from tensorflow_datasets.core.utils import py_utils
2529
from tensorflow_datasets.core.utils import tree_utils
2630
from tensorflow_datasets.core.utils import type_utils
2731
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
2832

2933

34+
@dataclasses.dataclass(frozen=True)
35+
class _getitem: # pylint: disable=invalid-name
36+
"""A pickable version of getitem that can be fed to Beam pipelines."""
37+
38+
decode_fn: Callable[[Any], Any]
39+
nest: Callable[[Any], Any]
40+
flat_example: list[Any]
41+
42+
def __call__(self, i):
43+
return self.decode_fn(self.nest([v[i] for v in self.flat_example]))
44+
45+
3046
class Dataset(sequence_feature.Sequence):
3147
"""A Dataset feature encodes a nested dataset.
3248
@@ -35,6 +51,12 @@ class Dataset(sequence_feature.Sequence):
3551
top-level `tf.data.Dataset` returned by `tfds.load`. At generation time, an
3652
iterable over the dataset elements is given.
3753
54+
If you use tfds.data_source and the NumPy path, `Dataset` will return
55+
a [data
56+
source](https://www.tensorflow.org/datasets/api_docs/python/tfds/data_source).
57+
The advantage of having a data source is that decoding will be lazily executed
58+
when you access each example in the dataset.
59+
3860
This is an experimental feature. Currently, only one level of nesting is
3961
supported and TF1 graph is not supported either.
4062
@@ -146,6 +168,54 @@ def decode_example(self, serialized_example, decoders=None):
146168
)
147169
return ds
148170

171+
def decode_example_np(
172+
self, serialized_example, decoders=None
173+
) -> python.PythonDataSource:
174+
"""See base class for details."""
175+
flatten = self.feature._flatten # pylint: disable=protected-access
176+
nest = self.feature._nest # pylint: disable=protected-access
177+
flat_example = flatten(serialized_example)
178+
flat_features = flatten(self.feature)
179+
num_slices: int | None = None
180+
181+
# First discover the number of slices in the Dataset. Notably, it's possible
182+
# that tensors have to be reshaped. We call slice a record in the Dataset.
183+
# We don't use `example` to avoid confusion with the `serialized_example`.
184+
for i, feature in enumerate(flat_features):
185+
if isinstance(feature, tensor_feature.Tensor) and feature.shape:
186+
try:
187+
flat_example[i] = flat_example[i].reshape((-1,) + feature.shape)
188+
except ValueError as e:
189+
raise ValueError(
190+
"The length of all elements of one slice should be the same."
191+
) from e
192+
feature_num_slices = flat_example[i].shape[0]
193+
else:
194+
feature_num_slices = len(flat_example[i])
195+
if num_slices is None:
196+
num_slices = feature_num_slices
197+
else:
198+
if feature_num_slices != num_slices:
199+
raise ValueError(
200+
"The length of elements of all slices should be the same. Got"
201+
f" {num_slices} and {feature_num_slices}"
202+
)
203+
if num_slices is None:
204+
raise ValueError("no feature was found.")
205+
206+
# Then, we apply the decode function on each slice.
207+
if isinstance(self.feature, top_level_feature.TopLevelFeature):
208+
# Only top-level features accept decoders.
209+
decode_fn = functools.partial(
210+
self.feature.decode_example_np, decoders=decoders
211+
)
212+
else:
213+
decode_fn = self.feature.decode_example_np
214+
215+
return python.PythonDataSource(
216+
length=num_slices, getitem=_getitem(decode_fn, nest, flat_example)
217+
)
218+
149219
def _flatten(self, x):
150220
"""See base class for details."""
151221
return [x]

0 commit comments

Comments
 (0)