Skip to content

Commit 5fed6fd

Browse files
Fix pytorch loader (#812)
* Throw error if can't apply preprocessing. Fix pytorch label conversion * black formatting * Disabling preprocessing durings tests * Raise error if value is not a tf tensor * Fixing broken tf1 framework dataset test
1 parent 5588582 commit 5fed6fd

File tree

4 files changed

+43
-7
lines changed

4 files changed

+43
-7
lines changed

armory/data/datasets.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,13 @@ def _generator_from_tfds(
268268
ds = ds.batch(batch_size, drop_remainder=False)
269269
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
270270

271+
if framework != "numpy" and (
272+
preprocessing_fn is not None or label_preprocessing_fn is not None
273+
):
274+
raise ValueError(
275+
f"Data/label preprocessing functions only supported for numpy framework. Selected {framework} framework"
276+
)
277+
271278
if framework == "numpy":
272279
ds = tfds.as_numpy(ds, graph=default_graph)
273280
generator = ArmoryDataGenerator(

armory/data/pytorch_loader.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import tensorflow as tf
23

34

45
class TFToTorchGenerator(torch.utils.data.IterableDataset):
@@ -9,12 +10,26 @@ def __init__(self, tf_dataset):
910
def __iter__(self):
1011
for ex in self.tf_dataset.take(-1):
1112
x, y = ex
12-
# manually handle adverarial dataset
13+
# separately handle benign/adversarial data formats
1314
if isinstance(x, tuple):
14-
x = (torch.from_numpy(x[0].numpy()), torch.from_numpy(x[1].numpy()))
15-
y = torch.from_numpy(y.numpy())
16-
# non-adversarial dataset
15+
x_torch = (
16+
torch.from_numpy(x[0].numpy()),
17+
torch.from_numpy(x[1].numpy()),
18+
)
1719
else:
18-
x = torch.from_numpy(x.numpy())
19-
y = torch.from_numpy(y.numpy())
20-
yield x, y
20+
x_torch = torch.from_numpy(x.numpy())
21+
22+
# separately handle tensor/object detection label formats
23+
if isinstance(y, dict):
24+
y_torch = {}
25+
for k, v in y.items():
26+
if isinstance(v, tf.Tensor):
27+
y_torch[k] = torch.from_numpy(v.numpy())
28+
else:
29+
raise ValueError(
30+
f"Expected all values to be of type tf.Tensor, but value at key {k} is of type {type(v)}"
31+
)
32+
else:
33+
y_torch = torch.from_numpy(y.numpy())
34+
35+
yield x_torch, y_torch

tests/test_pytorch/test_framework_dataset.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def test_pytorch_generator_cifar10():
1919
batch_size=batch_size,
2020
dataset_dir=DATASET_DIR,
2121
framework="pytorch",
22+
preprocessing_fn=None,
23+
fit_preprocessing_fn=None,
2224
)
2325

2426
assert isinstance(dataset, torch.utils.data.DataLoader)
@@ -38,6 +40,8 @@ def test_pytorch_generator_mnist():
3840
batch_size=batch_size,
3941
dataset_dir=DATASET_DIR,
4042
framework="pytorch",
43+
preprocessing_fn=None,
44+
fit_preprocessing_fn=None,
4145
)
4246

4347
assert isinstance(dataset, torch.utils.data.DataLoader)
@@ -57,6 +61,8 @@ def test_pytorch_generator_resisc():
5761
batch_size=batch_size,
5862
dataset_dir=DATASET_DIR,
5963
framework="pytorch",
64+
preprocessing_fn=None,
65+
fit_preprocessing_fn=None,
6066
)
6167

6268
assert isinstance(dataset, torch.utils.data.DataLoader)
@@ -76,6 +82,8 @@ def test_pytorch_generator_epochs():
7682
batch_size=batch_size,
7783
dataset_dir=DATASET_DIR,
7884
framework="pytorch",
85+
preprocessing_fn=None,
86+
fit_preprocessing_fn=None,
7987
)
8088

8189
cnt = 0
@@ -100,6 +108,8 @@ def test_tf_pytorch_equality():
100108
dataset_dir=DATASET_DIR,
101109
framework="tf",
102110
shuffle_files=False,
111+
preprocessing_fn=None,
112+
fit_preprocessing_fn=None,
103113
)
104114

105115
ds_pytorch = iter(
@@ -109,6 +119,8 @@ def test_tf_pytorch_equality():
109119
dataset_dir=DATASET_DIR,
110120
framework="pytorch",
111121
shuffle_files=False,
122+
preprocessing_fn=None,
123+
fit_preprocessing_fn=None,
112124
)
113125
)
114126

tests/test_tf1/test_framework_dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,7 @@ def test_tf_generator():
1717
batch_size=16,
1818
dataset_dir=DATASET_DIR,
1919
framework="tf",
20+
preprocessing_fn=None,
21+
fit_preprocessing_fn=None,
2022
)
2123
assert isinstance(dataset, (tf.compat.v2.data.Dataset, tf.compat.v1.data.Dataset))

0 commit comments

Comments
 (0)