Skip to content

Commit 21bbc63

Browse files
[Attention] Finish refactor attention file (huggingface#1879)
* [Attention] Finish refactor attention file * correct more * fix * more fixes * correct * up
1 parent 62608a9 commit 21bbc63

File tree

7 files changed

+15
-367
lines changed

7 files changed

+15
-367
lines changed

docs/source/api/models.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
5656
[[autodoc]] Transformer2DModel
5757

5858
## Transformer2DModelOutput
59-
[[autodoc]] models.attention.Transformer2DModelOutput
59+
[[autodoc]] models.transformer_2d.Transformer2DModelOutput
6060

6161
## PriorTransformer
6262
[[autodoc]] models.prior_transformer.PriorTransformer

src/diffusers/models/attention.py

Lines changed: 0 additions & 357 deletions
Large diffs are not rendered by default.

src/diffusers/models/dual_transformer_2d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,9 @@ def forward(
119119
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
120120
121121
Returns:
122-
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
123-
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
124-
tensor.
122+
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
123+
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
124+
returning a tuple, the first element is the sample tensor.
125125
"""
126126
input_states = hidden_states
127127

src/diffusers/models/transformer_2d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,9 @@ def forward(
189189
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
190190
191191
Returns:
192-
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
193-
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
194-
tensor.
192+
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
193+
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
194+
returning a tuple, the first element is the sample tensor.
195195
"""
196196
# 1. Input
197197
if self.is_input_continuous:

src/diffusers/models/unet_2d_blocks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
import torch
1616
from torch import nn
1717

18-
from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
18+
from .attention import AttentionBlock
1919
from .cross_attention import CrossAttention, CrossAttnAddedKVProcessor
20+
from .dual_transformer_2d import DualTransformer2DModel
2021
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
22+
from .transformer_2d import Transformer2DModel
2123

2224

2325
def get_down_block(

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66

77
from ...configuration_utils import ConfigMixin, register_to_config
88
from ...models import ModelMixin
9-
from ...models.attention import CrossAttention, DualTransformer2DModel, Transformer2DModel
9+
from ...models.attention import CrossAttention
1010
from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor
11+
from ...models.dual_transformer_2d import DualTransformer2DModel
1112
from ...models.embeddings import TimestepEmbedding, Timesteps
13+
from ...models.transformer_2d import Transformer2DModel
1214
from ...models.unet_2d_condition import UNet2DConditionOutput
1315
from ...utils import logging
1416

tests/test_layers_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
import torch
2121
from torch import nn
2222

23-
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock, Transformer2DModel
23+
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock
2424
from diffusers.models.embeddings import get_timestep_embedding
2525
from diffusers.models.resnet import Downsample2D, Upsample2D
26+
from diffusers.models.transformer_2d import Transformer2DModel
2627
from diffusers.utils import torch_device
2728

2829

0 commit comments

Comments
 (0)