Skip to content

Commit d36103a

Browse files
[Tests] Speed up test (huggingface#2919)
speed up test
1 parent b3c437e commit d36103a

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

tests/models/test_models_unet_3d_condition.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,19 +88,17 @@ def output_shape(self):
8888

8989
def prepare_init_args_and_inputs_for_common(self):
9090
init_dict = {
91-
"block_out_channels": (32, 64, 64, 64),
91+
"block_out_channels": (32, 64),
9292
"down_block_types": (
93-
"CrossAttnDownBlock3D",
94-
"CrossAttnDownBlock3D",
9593
"CrossAttnDownBlock3D",
9694
"DownBlock3D",
9795
),
98-
"up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
96+
"up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D"),
9997
"cross_attention_dim": 32,
100-
"attention_head_dim": 4,
98+
"attention_head_dim": 8,
10199
"out_channels": 4,
102100
"in_channels": 4,
103-
"layers_per_block": 2,
101+
"layers_per_block": 1,
104102
"sample_size": 32,
105103
}
106104
inputs_dict = self.dummy_input

0 commit comments

Comments
 (0)