
Commit 9ae9059

hisushanta, sayakpaul, DN6, and yiyixuxu authored
Replacing the nn.Mish activation function with a get_activation function. (huggingface#5651)
* I added a new docstring to the class. This makes it easier for other developers to understand what it does and where it is used.
* Update src/diffusers/models/unet_2d_blocks.py: change suggested by a maintainer. Co-authored-by: Sayak Paul <[email protected]>
* Update src/diffusers/models/unet_2d_blocks.py: add suggested text. Co-authored-by: Sayak Paul <[email protected]>
* Update unet_2d_blocks.py: changed the "Parameters" heading to "Args".
* Update unet_2d_blocks.py: set proper indentation in this file.
* Update unet_2d_blocks.py: small change to the act_fn argument line.
* Ran black to reformat the code style.
* Update unet_2d_blocks.py: added a docstring similar to the one in the original diffusion repository.
* Removed the dummy variable defined in both the encoder and the decoder.
* Ran black again to reformat the file.
* Removed the redundant line from adapter.py.
* Reformatted the file with black.
* Replacing the nn.Mish activation function with a get_activation function lets developers more easily choose the right activation function for their task. Removing redundant variables also improves code readability and maintainability.
* Tried to fix the failing check: Fast tests for PRs / Fast PyTorch Models & Schedulers CPU tests (pull_request).
* Update src/diffusers/models/resnet.py. Co-authored-by: YiYi Xu <[email protected]>

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
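
For context, get_activation resolves a string name to a torch.nn activation module, so blocks can accept an activation by name instead of hard-coding nn.Mish. The sketch below shows the general pattern only; the name table and error handling here are illustrative assumptions, not the exact diffusers implementation.

import torch.nn as nn

# Illustrative name-to-module table (an assumption, not the exact
# diffusers mapping); each lookup builds a fresh module instance.
_ACTIVATIONS = {
    "swish": nn.SiLU,
    "silu": nn.SiLU,
    "mish": nn.Mish,
    "gelu": nn.GELU,
    "relu": nn.ReLU,
}

def get_activation(act_fn: str) -> nn.Module:
    """Return the activation module registered under `act_fn`."""
    try:
        return _ACTIVATIONS[act_fn.lower()]()
    except KeyError:
        raise ValueError(f"Unsupported activation function: {act_fn}")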
1 parent 7942bb8 commit 9ae9059

2 files changed: +18 -6 lines changed


src/diffusers/models/resnet.py

Lines changed: 16 additions & 4 deletions
@@ -778,16 +778,22 @@ class Conv1dBlock(nn.Module):
         out_channels (`int`): Number of output channels.
         kernel_size (`int` or `tuple`): Size of the convolving kernel.
         n_groups (`int`, default `8`): Number of groups to separate the channels into.
+        activation (`str`, defaults to `mish`): Name of the activation function.
     """
 
     def __init__(
-        self, inp_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], n_groups: int = 8
+        self,
+        inp_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        n_groups: int = 8,
+        activation: str = "mish",
     ):
         super().__init__()
 
         self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
         self.group_norm = nn.GroupNorm(n_groups, out_channels)
-        self.mish = nn.Mish()
+        self.mish = get_activation(activation)
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         intermediate_repr = self.conv1d(inputs)
@@ -808,16 +814,22 @@ class ResidualTemporalBlock1D(nn.Module):
         out_channels (`int`): Number of output channels.
         embed_dim (`int`): Embedding dimension.
         kernel_size (`int` or `tuple`): Size of the convolving kernel.
+        activation (`str`, defaults to `mish`): Name of the activation function to use.
     """
 
     def __init__(
-        self, inp_channels: int, out_channels: int, embed_dim: int, kernel_size: Union[int, Tuple[int, int]] = 5
+        self,
+        inp_channels: int,
+        out_channels: int,
+        embed_dim: int,
+        kernel_size: Union[int, Tuple[int, int]] = 5,
+        activation: str = "mish",
     ):
         super().__init__()
         self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
         self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
 
-        self.time_emb_act = nn.Mish()
+        self.time_emb_act = get_activation(activation)
         self.time_emb = nn.Linear(embed_dim, out_channels)
 
         self.residual_conv = (
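
A quick usage sketch of the updated blocks (the tensor shapes and the "gelu"/"silu" choices are illustrative; it assumes both classes are importable from diffusers.models.resnet as in this revision):

import torch
from diffusers.models.resnet import Conv1dBlock, ResidualTemporalBlock1D

# Defaults keep the previous nn.Mish behavior; other names now work too.
block = Conv1dBlock(inp_channels=8, out_channels=16, kernel_size=3, activation="gelu")
x = torch.randn(1, 8, 32)     # (batch, channels, length)
print(block(x).shape)         # torch.Size([1, 16, 32])

res_block = ResidualTemporalBlock1D(inp_channels=8, out_channels=16, embed_dim=4, activation="silu")
t = torch.randn(1, 4)         # time embedding of size embed_dim
print(res_block(x, t).shape)  # torch.Size([1, 16, 32])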

src/diffusers/models/vq_model.py

Lines changed: 2 additions & 2 deletions
@@ -162,8 +162,8 @@ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[
                 If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
                 is returned.
         """
-        x = sample
-        h = self.encode(x).latents
+
+        h = self.encode(sample).latents
         dec = self.decode(h).sample
 
         if not return_dict:
