88from captum .module .binary_concrete_stochastic_gates import BinaryConcreteStochasticGates
99from captum .testing .helpers import BaseTest
1010from captum .testing .helpers .basic import assertTensorAlmostEqual
11- from parameterized import parameterized_class
1211
1312
14- @parameterized_class (
15- [
16- {"testing_device" : "cpu" },
17- {"testing_device" : "cuda" },
18- ]
19- )
2013class TestBinaryConcreteStochasticGates (BaseTest ):
21- # pyre-fixme[13]: Attribute `testing_device` is never initialized.
22- testing_device : str
14+ testing_device : str = "cpu"
2315
2416 def setUp (self ) -> None :
2517 super ().setUp ()
26- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
27- # `testing_device`.
2818 if self .testing_device == "cuda" and not torch .cuda .is_available ():
2919 raise unittest .SkipTest ("Skipping GPU test since CUDA not available." )
3020
3121 def test_bcstg_1d_input (self ) -> None :
3222
3323 dim = 3
34- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
35- # `testing_device`.
3624 bcstg = BinaryConcreteStochasticGates (dim ).to (self .testing_device )
3725 input_tensor = torch .tensor (
3826 [
@@ -57,8 +45,6 @@ def test_bcstg_1d_input_with_reg_reduction(self) -> None:
5745
5846 dim = 3
5947 mean_bcstg = BinaryConcreteStochasticGates (dim , reg_reduction = "mean" ).to (
60- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
61- # `testing_device`.
6248 self .testing_device
6349 )
6450 none_bcstg = BinaryConcreteStochasticGates (dim , reg_reduction = "none" ).to (
@@ -82,8 +68,6 @@ def test_bcstg_1d_input_with_reg_reduction(self) -> None:
8268 def test_bcstg_1d_input_with_n_gates_error (self ) -> None :
8369
8470 dim = 3
85- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
86- # `testing_device`.
8771 bcstg = BinaryConcreteStochasticGates (dim ).to (self .testing_device )
8872 input_tensor = torch .tensor ([0.0 , 0.1 , 0.2 ]).to (self .testing_device )
8973
@@ -95,14 +79,10 @@ def test_bcstg_num_mask_not_equal_dim_error(self) -> None:
9579 mask = torch .tensor ([0 , 0 , 1 ]) # only two distinct masks, but given dim is 3
9680
9781 with self .assertRaises (AssertionError ):
98- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
99- # `testing_device`.
10082 BinaryConcreteStochasticGates (dim , mask = mask ).to (self .testing_device )
10183
10284 def test_gates_values_matching_dim_when_eval (self ) -> None :
10385 dim = 3
104- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
105- # `testing_device`.
10686 bcstg = BinaryConcreteStochasticGates (dim ).to (self .testing_device )
10787 input_tensor = torch .tensor (
10888 [
@@ -118,8 +98,6 @@ def test_gates_values_matching_dim_when_eval(self) -> None:
11898 def test_bcstg_1d_input_with_mask (self ) -> None :
11999
120100 dim = 2
121- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
122- # `testing_device`.
123101 mask = torch .tensor ([0 , 0 , 1 ]).to (self .testing_device )
124102 bcstg = BinaryConcreteStochasticGates (dim , mask = mask ).to (self .testing_device )
125103 input_tensor = torch .tensor (
@@ -144,8 +122,6 @@ def test_bcstg_1d_input_with_mask(self) -> None:
144122 def test_bcstg_2d_input (self ) -> None :
145123
146124 dim = 3 * 2
147- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
148- # `testing_device`.
149125 bcstg = BinaryConcreteStochasticGates (dim ).to (self .testing_device )
150126
151127 # shape(2,3,2)
@@ -185,8 +161,6 @@ def test_bcstg_2d_input(self) -> None:
185161 def test_bcstg_2d_input_with_n_gates_error (self ) -> None :
186162
187163 dim = 5
188- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
189- # `testing_device`.
190164 bcstg = BinaryConcreteStochasticGates (dim ).to (self .testing_device )
191165 input_tensor = torch .tensor (
192166 [
@@ -210,8 +184,6 @@ def test_bcstg_2d_input_with_mask(self) -> None:
210184 [1 , 1 ],
211185 [0 , 2 ],
212186 ]
213- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
214- # `testing_device`.
215187 ).to (self .testing_device )
216188 bcstg = BinaryConcreteStochasticGates (dim , mask = mask ).to (self .testing_device )
217189
@@ -252,8 +224,6 @@ def test_bcstg_2d_input_with_mask(self) -> None:
252224 def test_get_gate_values_1d_input (self ) -> None :
253225
254226 dim = 3
255- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
256- # `testing_device`.
257227 bcstg = BinaryConcreteStochasticGates (dim ).to (self .testing_device )
258228 input_tensor = torch .tensor (
259229 [
@@ -273,8 +243,6 @@ def test_get_gate_values_1d_input_with_mask(self) -> None:
273243
274244 dim = 2
275245 mask = torch .tensor ([0 , 1 , 1 ])
276- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
277- # `testing_device`.
278246 bcstg = BinaryConcreteStochasticGates (dim , mask = mask ).to (self .testing_device )
279247 input_tensor = torch .tensor (
280248 [
@@ -293,8 +261,6 @@ def test_get_gate_values_1d_input_with_mask(self) -> None:
293261 def test_get_gate_values_2d_input (self ) -> None :
294262
295263 dim = 3 * 2
296- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
297- # `testing_device`.
298264 bcstg = BinaryConcreteStochasticGates (dim ).to (self .testing_device )
299265
300266 # shape(2,3,2)
@@ -326,8 +292,6 @@ def test_get_gate_values_clamp(self) -> None:
326292 torch .tensor ([10.0 , - 10.0 , 10.0 ]),
327293 lower_bound = - 2 ,
328294 upper_bound = 2 ,
329- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
330- # `testing_device`.
331295 ).to (self .testing_device )
332296
333297 clamped_gate_values = bcstg .get_gate_values ().cpu ().tolist ()
@@ -350,8 +314,6 @@ def test_get_gate_values_2d_input_with_mask(self) -> None:
350314 [0 , 2 ],
351315 ]
352316 )
353- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
354- # `testing_device`.
355317 bcstg = BinaryConcreteStochasticGates (dim , mask = mask ).to (self .testing_device )
356318
357319 input_tensor = torch .tensor (
@@ -379,8 +341,6 @@ def test_get_gate_values_2d_input_with_mask(self) -> None:
379341 def test_get_gate_active_probs_1d_input (self ) -> None :
380342
381343 dim = 3
382- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
383- # `testing_device`.
384344 bcstg = BinaryConcreteStochasticGates (dim ).to (self .testing_device )
385345 input_tensor = torch .tensor (
386346 [
@@ -402,8 +362,6 @@ def test_get_gate_active_probs_1d_input_with_mask(self) -> None:
402362
403363 dim = 2
404364 mask = torch .tensor ([0 , 1 , 1 ])
405- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
406- # `testing_device`.
407365 bcstg = BinaryConcreteStochasticGates (dim , mask = mask ).to (self .testing_device )
408366 input_tensor = torch .tensor (
409367 [
@@ -424,8 +382,6 @@ def test_get_gate_active_probs_1d_input_with_mask(self) -> None:
424382 def test_get_gate_active_probs_2d_input (self ) -> None :
425383
426384 dim = 3 * 2
427- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
428- # `testing_device`.
429385 bcstg = BinaryConcreteStochasticGates (dim ).to (self .testing_device )
430386
431387 # shape(2,3,2)
@@ -463,8 +419,6 @@ def test_get_gate_active_probs_2d_input_with_mask(self) -> None:
463419 [0 , 2 ],
464420 ]
465421 )
466- # pyre-fixme[16]: `TestBinaryConcreteStochasticGates` has no attribute
467- # `testing_device`.
468422 bcstg = BinaryConcreteStochasticGates (dim , mask = mask ).to (self .testing_device )
469423
470424 input_tensor = torch .tensor (
0 commit comments