Skip to content

Commit 7139f0e

Browse files
authored
fix: norm group test for UNet3D. (huggingface#2959)
1 parent 8c530fc commit 7139f0e

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

tests/models/test_models_unet_3d_condition.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,11 @@ def test_xformers_enable_works(self):
119119
== "XFormersAttnProcessor"
120120
), "xformers is not enabled"
121121

122-
# Overriding because `block_out_channels` needs to be different for this model.
122+
# Overriding to set `norm_num_groups` needs to be different for this model.
123123
def test_forward_with_norm_groups(self):
124124
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
125125

126126
init_dict["norm_num_groups"] = 32
127-
init_dict["block_out_channels"] = (32, 64, 64, 64)
128127

129128
model = self.model_class(**init_dict)
130129
model.to(torch_device)

0 commit comments

Comments
 (0)