File tree Expand file tree Collapse file tree 4 files changed +43
-7
lines changed Expand file tree Collapse file tree 4 files changed +43
-7
lines changed Original file line number Diff line number Diff line change @@ -268,6 +268,13 @@ def _generator_from_tfds(
268
268
ds = ds .batch (batch_size , drop_remainder = False )
269
269
ds = ds .prefetch (tf .data .experimental .AUTOTUNE )
270
270
271
+ if framework != "numpy" and (
272
+ preprocessing_fn is not None or label_preprocessing_fn is not None
273
+ ):
274
+ raise ValueError (
275
+ f"Data/label preprocessing functions only supported for numpy framework. Selected { framework } framework"
276
+ )
277
+
271
278
if framework == "numpy" :
272
279
ds = tfds .as_numpy (ds , graph = default_graph )
273
280
generator = ArmoryDataGenerator (
Original file line number Diff line number Diff line change 1
1
import torch
2
+ import tensorflow as tf
2
3
3
4
4
5
class TFToTorchGenerator (torch .utils .data .IterableDataset ):
@@ -9,12 +10,26 @@ def __init__(self, tf_dataset):
9
10
def __iter__ (self ):
10
11
for ex in self .tf_dataset .take (- 1 ):
11
12
x , y = ex
12
- # manually handle adverarial dataset
13
+ # separately handle benign/adversarial data formats
13
14
if isinstance (x , tuple ):
14
- x = (torch .from_numpy (x [0 ].numpy ()), torch .from_numpy (x [1 ].numpy ()))
15
- y = torch .from_numpy (y .numpy ())
16
- # non-adversarial dataset
15
+ x_torch = (
16
+ torch .from_numpy (x [0 ].numpy ()),
17
+ torch .from_numpy (x [1 ].numpy ()),
18
+ )
17
19
else :
18
- x = torch .from_numpy (x .numpy ())
19
- y = torch .from_numpy (y .numpy ())
20
- yield x , y
20
+ x_torch = torch .from_numpy (x .numpy ())
21
+
22
+ # separately handle tensor/object detection label formats
23
+ if isinstance (y , dict ):
24
+ y_torch = {}
25
+ for k , v in y .items ():
26
+ if isinstance (v , tf .Tensor ):
27
+ y_torch [k ] = torch .from_numpy (v .numpy ())
28
+ else :
29
+ raise ValueError (
30
+ f"Expected all values to be of type tf.Tensor, but value at key { k } is of type { type (v )} "
31
+ )
32
+ else :
33
+ y_torch = torch .from_numpy (y .numpy ())
34
+
35
+ yield x_torch , y_torch
Original file line number Diff line number Diff line change @@ -19,6 +19,8 @@ def test_pytorch_generator_cifar10():
19
19
batch_size = batch_size ,
20
20
dataset_dir = DATASET_DIR ,
21
21
framework = "pytorch" ,
22
+ preprocessing_fn = None ,
23
+ fit_preprocessing_fn = None ,
22
24
)
23
25
24
26
assert isinstance (dataset , torch .utils .data .DataLoader )
@@ -38,6 +40,8 @@ def test_pytorch_generator_mnist():
38
40
batch_size = batch_size ,
39
41
dataset_dir = DATASET_DIR ,
40
42
framework = "pytorch" ,
43
+ preprocessing_fn = None ,
44
+ fit_preprocessing_fn = None ,
41
45
)
42
46
43
47
assert isinstance (dataset , torch .utils .data .DataLoader )
@@ -57,6 +61,8 @@ def test_pytorch_generator_resisc():
57
61
batch_size = batch_size ,
58
62
dataset_dir = DATASET_DIR ,
59
63
framework = "pytorch" ,
64
+ preprocessing_fn = None ,
65
+ fit_preprocessing_fn = None ,
60
66
)
61
67
62
68
assert isinstance (dataset , torch .utils .data .DataLoader )
@@ -76,6 +82,8 @@ def test_pytorch_generator_epochs():
76
82
batch_size = batch_size ,
77
83
dataset_dir = DATASET_DIR ,
78
84
framework = "pytorch" ,
85
+ preprocessing_fn = None ,
86
+ fit_preprocessing_fn = None ,
79
87
)
80
88
81
89
cnt = 0
@@ -100,6 +108,8 @@ def test_tf_pytorch_equality():
100
108
dataset_dir = DATASET_DIR ,
101
109
framework = "tf" ,
102
110
shuffle_files = False ,
111
+ preprocessing_fn = None ,
112
+ fit_preprocessing_fn = None ,
103
113
)
104
114
105
115
ds_pytorch = iter (
@@ -109,6 +119,8 @@ def test_tf_pytorch_equality():
109
119
dataset_dir = DATASET_DIR ,
110
120
framework = "pytorch" ,
111
121
shuffle_files = False ,
122
+ preprocessing_fn = None ,
123
+ fit_preprocessing_fn = None ,
112
124
)
113
125
)
114
126
Original file line number Diff line number Diff line change @@ -17,5 +17,7 @@ def test_tf_generator():
17
17
batch_size = 16 ,
18
18
dataset_dir = DATASET_DIR ,
19
19
framework = "tf" ,
20
+ preprocessing_fn = None ,
21
+ fit_preprocessing_fn = None ,
20
22
)
21
23
assert isinstance (dataset , (tf .compat .v2 .data .Dataset , tf .compat .v1 .data .Dataset ))
You can’t perform that action at this time.
0 commit comments