Modular custom config object serialization #11868

Merged: 2 commits, Jul 5, 2025
4 changes: 4 additions & 0 deletions src/diffusers/configuration_utils.py
@@ -601,6 +601,10 @@ def to_json_saveable(value):
                 value = value.tolist()
             elif isinstance(value, Path):
                 value = value.as_posix()
+            elif hasattr(value, "to_dict") and callable(value.to_dict):
+                value = value.to_dict()
+            elif isinstance(value, list):
+                value = [to_json_saveable(v) for v in value]
             return value
 
         if "quantization_config" in config_dict:
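With these two branches, the config JSON writer can now serialize custom config objects (anything exposing a callable `to_dict`) and recurse into lists of them. A minimal standalone sketch of the behavior (hypothetical names; the real `to_json_saveable` is a nested helper inside the config serialization path):

from dataclasses import asdict, dataclass

@dataclass
class DummyConfig:  # hypothetical stand-in for e.g. LayerSkipConfig
    indices: list

    def to_dict(self):
        return asdict(self)

def to_json_saveable(value):
    # Objects that know how to serialize themselves are asked to do so.
    if hasattr(value, "to_dict") and callable(value.to_dict):
        return value.to_dict()
    # Lists are converted element by element, so a list of configs becomes a list of dicts.
    if isinstance(value, list):
        return [to_json_saveable(v) for v in value]
    return value

print(to_json_saveable([DummyConfig(indices=[7, 8])]))  # [{'indices': [7, 8]}]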
9 changes: 7 additions & 2 deletions src/diffusers/guiders/auto_guidance.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -66,7 +66,7 @@ def __init__(
         self,
         guidance_scale: float = 7.5,
         auto_guidance_layers: Optional[Union[int, List[int]]] = None,
-        auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
+        auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
         dropout: Optional[float] = None,
         guidance_rescale: float = 0.0,
         use_original_formulation: bool = False,
@@ -104,13 +104,18 @@
                 LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers
             ]
 
+        if isinstance(auto_guidance_config, dict):
+            auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config)
+
         if isinstance(auto_guidance_config, LayerSkipConfig):
             auto_guidance_config = [auto_guidance_config]
 
         if not isinstance(auto_guidance_config, list):
             raise ValueError(
                 f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
             )
+        elif isinstance(next(iter(auto_guidance_config), None), dict):
+            auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config]
 
         self.auto_guidance_config = auto_guidance_config
         self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]
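In practice, `AutoGuidance` should now accept a plain dict (or a list of dicts, handled by the `elif` after the list check) anywhere a `LayerSkipConfig` was previously required, which is what makes round-tripping through serialized JSON possible. A rough usage sketch, assuming `AutoGuidance` is importable from `diffusers.guiders` and using the `LayerSkipConfig` field names visible in this diff:

from diffusers.guiders import AutoGuidance

guider = AutoGuidance(
    guidance_scale=7.5,
    auto_guidance_config={"indices": [7, 8, 9], "fqn": "auto", "dropout": 1.0},
)
# The dict is normalized to a one-element list of LayerSkipConfig instances.
assert isinstance(guider.auto_guidance_config, list)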
46 changes: 37 additions & 9 deletions src/diffusers/guiders/perturbed_attention_guidance.py
@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from ..configuration_utils import register_to_config
 from ..hooks import LayerSkipConfig
+from ..utils import get_logger
 from .skip_layer_guidance import SkipLayerGuidance
 
 
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+
 class PerturbedAttentionGuidance(SkipLayerGuidance):
     """
     Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
@@ -48,8 +52,8 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
         The fraction of the total number of denoising steps after which perturbed attention guidance stops.
     perturbed_guidance_layers (`int` or `List[int]`, *optional*):
         The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
-        If not provided, `skip_layer_config` must be provided.
-    skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
+        If not provided, `perturbed_guidance_config` must be provided.
+    perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
         The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
         `LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
     guidance_rescale (`float`, defaults to `0.0`):
@@ -79,36 +83,60 @@ def __init__(
         perturbed_guidance_start: float = 0.01,
         perturbed_guidance_stop: float = 0.2,
         perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
-        skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
+        perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
         use_original_formulation: bool = False,
         start: float = 0.0,
         stop: float = 1.0,
     ):
-        if skip_layer_config is None:
+        if perturbed_guidance_config is None:
             if perturbed_guidance_layers is None:
                 raise ValueError(
-                    "`perturbed_guidance_layers` must be provided if `skip_layer_config` is not specified."
+                    "`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
                 )
-            skip_layer_config = LayerSkipConfig(
+            perturbed_guidance_config = LayerSkipConfig(
                 indices=perturbed_guidance_layers,
                 fqn="auto",
                 skip_attention=False,
                 skip_attention_scores=True,
                 skip_ff=False,
             )
         else:
             if perturbed_guidance_layers is not None:
                 raise ValueError(
-                    "`perturbed_guidance_layers` should not be provided if `skip_layer_config` is specified."
+                    "`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
                )
 
+        if isinstance(perturbed_guidance_config, dict):
+            perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)
+
+        if isinstance(perturbed_guidance_config, LayerSkipConfig):
+            perturbed_guidance_config = [perturbed_guidance_config]
+
+        if not isinstance(perturbed_guidance_config, list):
+            raise ValueError(
+                "`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
+            )
+        elif isinstance(next(iter(perturbed_guidance_config), None), dict):
+            perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config]
+
+        for config in perturbed_guidance_config:
+            if config.skip_attention or not config.skip_attention_scores or config.skip_ff:
+                logger.warning(
+                    "Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
+                    "Please check your configuration. Modifying the config to match the expected values."
+                )
+            config.skip_attention = False
+            config.skip_attention_scores = True
+            config.skip_ff = False
+
         super().__init__(
             guidance_scale=guidance_scale,
             skip_layer_guidance_scale=perturbed_guidance_scale,
             skip_layer_guidance_start=perturbed_guidance_start,
             skip_layer_guidance_stop=perturbed_guidance_stop,
             skip_layer_guidance_layers=perturbed_guidance_layers,
-            skip_layer_config=skip_layer_config,
+            skip_layer_config=perturbed_guidance_config,
             guidance_rescale=guidance_rescale,
             use_original_formulation=use_original_formulation,
             start=start,
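Beyond the dict conversion shared with the other guiders, PAG also validates the resulting configs: perturbing attention scores is the whole point of the method, so contradictory flags are overwritten (with a warning) rather than rejected. A sketch of the expected behavior, with the import path and attribute names assumed from this diff:

from diffusers.guiders import PerturbedAttentionGuidance

# `skip_attention=True` contradicts PAG, so a warning is logged and the
# config is normalized in place.
guider = PerturbedAttentionGuidance(
    perturbed_guidance_scale=2.5,
    perturbed_guidance_config={"indices": [2, 9], "fqn": "auto", "skip_attention": True},
)
config = guider.skip_layer_config[0]  # stored under the parent class's attribute name
assert (config.skip_attention, config.skip_attention_scores, config.skip_ff) == (False, True, False)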
9 changes: 7 additions & 2 deletions src/diffusers/guiders/skip_layer_guidance.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -95,7 +95,7 @@ def __init__(
         skip_layer_guidance_start: float = 0.01,
         skip_layer_guidance_stop: float = 0.2,
         skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
-        skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
+        skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
         use_original_formulation: bool = False,
         start: float = 0.0,
@@ -135,13 +135,18 @@
             )
         skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
 
+        if isinstance(skip_layer_config, dict):
+            skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)
+
         if isinstance(skip_layer_config, LayerSkipConfig):
             skip_layer_config = [skip_layer_config]
 
         if not isinstance(skip_layer_config, list):
             raise ValueError(
                 f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
             )
+        elif isinstance(next(iter(skip_layer_config), None), dict):
+            skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]
 
         self.skip_layer_config = skip_layer_config
         self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
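Note that the list branch only sniffs the first element (`next(iter(...), None)`), so a list passed here should be homogeneous: either all dicts or all `LayerSkipConfig` instances, not a mix. A short sketch under the same import-path assumption as above:

from diffusers.guiders import SkipLayerGuidance

guider = SkipLayerGuidance(
    guidance_scale=7.5,
    skip_layer_guidance_scale=2.8,
    skip_layer_config=[
        {"indices": [6, 7], "fqn": "auto"},
        {"indices": [10], "fqn": "auto"},
    ],  # each dict goes through LayerSkipConfig.from_dict
)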
5 changes: 5 additions & 0 deletions src/diffusers/guiders/smoothed_energy_guidance.py
@@ -125,13 +125,18 @@ def __init__(
             )
         seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
 
+        if isinstance(seg_guidance_config, dict):
+            seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)
+
         if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
             seg_guidance_config = [seg_guidance_config]
 
         if not isinstance(seg_guidance_config, list):
             raise ValueError(
                 f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
             )
+        elif isinstance(next(iter(seg_guidance_config), None), dict):
+            seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]
 
         self.seg_guidance_config = seg_guidance_config
         self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]
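The same normalization, with `SmoothedEnergyGuidanceConfig` in place of `LayerSkipConfig`. A minimal sketch (import path and constructor arguments are assumptions, not taken from this diff):

from diffusers.guiders import SmoothedEnergyGuidance

guider = SmoothedEnergyGuidance(
    guidance_scale=7.5,
    seg_guidance_config={"indices": [5, 6], "fqn": "auto"},
)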
9 changes: 8 additions & 1 deletion src/diffusers/hooks/layer_skip.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from typing import Callable, List, Optional
 
 import torch
 
@@ -78,6 +78,13 @@ def __post_init__(self):
                 "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
             )
 
+    def to_dict(self):
+        return asdict(self)
+
+    @staticmethod
+    def from_dict(data: dict) -> "LayerSkipConfig":
+        return LayerSkipConfig(**data)
+
 
 class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
     def __torch_function__(self, func, types, args=(), kwargs=None):
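Because `to_dict` is just `dataclasses.asdict` and `from_dict` re-invokes the constructor (so the `__post_init__` validation runs again), the pair round-trips losslessly. A quick sketch:

from diffusers.hooks import LayerSkipConfig

config = LayerSkipConfig(indices=[7, 8, 9], fqn="auto")
data = config.to_dict()                     # plain, JSON-serializable dict
restored = LayerSkipConfig.from_dict(data)  # __post_init__ validation runs again
assert restored == config                   # dataclasses compare field by field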
9 changes: 8 additions & 1 deletion src/diffusers/hooks/smoothed_energy_guidance_utils.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from typing import List, Optional
 
 import torch
 
@@ -51,6 +51,13 @@ class SmoothedEnergyGuidanceConfig:
     fqn: str = "auto"
     _query_proj_identifiers: List[str] = None
 
+    def to_dict(self):
+        return asdict(self)
+
+    @staticmethod
+    def from_dict(data: dict) -> "SmoothedEnergyGuidanceConfig":
+        return SmoothedEnergyGuidanceConfig(**data)
+
 
 class SmoothedEnergyGuidanceHook(ModelHook):
     def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
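The same round-trip applies to `SmoothedEnergyGuidanceConfig`; note that `asdict` also emits the underscore-prefixed `_query_proj_identifiers` field, and `from_dict` accepts it back as a keyword argument. A sketch, with the name `indices` assumed for the config's first field:

from diffusers.hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig

config = SmoothedEnergyGuidanceConfig(indices=[3, 4], fqn="auto")
data = config.to_dict()  # includes `_query_proj_identifiers` (default None)
assert SmoothedEnergyGuidanceConfig.from_dict(data) == config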