Skip to content

Commit 1d37f42

Browse files
zRzRzRzRzRzRzROleehyOyiyixuxugithub-actions[bot]
authored
Modify the implementation of retrieve_timesteps in CogView4-Control. (#11125)
* 1 * change to channel 1 * cogview4 control training * add CacheMixin * 1 * remove initial_input_channels change for val * 1 * update * use 3.5 * new loss * 1 * use imagetoken * for megatron convert * 1 * train con and uc * 2 * remove guidance_scale * Update pipeline_cogview4_control.py * fix * use cogview4 pipeline with timestep * update shift_factor * remove the uncond * add max length * change convert and use GLMModel instead of GLMForCasualLM * fix * [cogview4] Add attention mask support to transformer model * [fix] Add attention mask for padded token * update * remove padding type * Update train_control_cogview4.py * resolve conflicts with #10981 * add control convert * use control format * fix * add missing import * update with cogview4 formate * make style * Update pipeline_cogview4_control.py * Update pipeline_cogview4_control.py * remove * Update pipeline_cogview4_control.py * put back * Apply style fixes --------- Co-authored-by: OleehyO <[email protected]> Co-authored-by: yiyixuxu <[email protected]> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 0213179 commit 1d37f42

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def calculate_shift(
6868
return mu
6969

7070

71-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
71+
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps
7272
def retrieve_timesteps(
7373
scheduler,
7474
num_inference_steps: Optional[int] = None,
@@ -100,10 +100,19 @@ def retrieve_timesteps(
100100
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
101101
second element is the number of inference steps.
102102
"""
103+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
104+
accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
105+
103106
if timesteps is not None and sigmas is not None:
104-
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
105-
if timesteps is not None:
106-
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
107+
if not accepts_timesteps and not accepts_sigmas:
108+
raise ValueError(
109+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
110+
f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
111+
)
112+
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
113+
timesteps = scheduler.timesteps
114+
num_inference_steps = len(timesteps)
115+
elif timesteps is not None and sigmas is None:
107116
if not accepts_timesteps:
108117
raise ValueError(
109118
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -112,9 +121,8 @@ def retrieve_timesteps(
112121
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
113122
timesteps = scheduler.timesteps
114123
num_inference_steps = len(timesteps)
115-
elif sigmas is not None:
116-
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
117-
if not accept_sigmas:
124+
elif timesteps is None and sigmas is not None:
125+
if not accepts_sigmas:
118126
raise ValueError(
119127
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
120128
f" sigmas schedules. Please check whether you are using the correct scheduler."
@@ -515,8 +523,8 @@ def __call__(
515523
The output format of the generate image. Choose between
516524
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
517525
return_dict (`bool`, *optional*, defaults to `True`):
518-
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
519-
of a plain tuple.
526+
Whether or not to return a [`~pipelines.pipeline_CogView4.CogView4PipelineOutput`] instead of a plain
527+
tuple.
520528
attention_kwargs (`dict`, *optional*):
521529
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
522530
`self.processor` in
@@ -532,7 +540,6 @@ def __call__(
532540
`._callback_tensor_inputs` attribute of your pipeline class.
533541
max_sequence_length (`int`, defaults to `224`):
534542
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
535-
536543
Examples:
537544
538545
Returns:

0 commit comments

Comments
 (0)