Hunyuan Video Framepack F1 #11534

Merged: 4 commits merged on May 12, 2025

Changes from 1 commit
support framepack f1
a-r-r-o-w committed May 9, 2025
commit 99960c4667321ec4e9f891d95e9b620bfa527883
@@ -14,6 +14,7 @@

import inspect
import math
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -232,6 +233,11 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps


class FramepackInferenceType(str, Enum):
    VANILLA = "vanilla"
    INVERTED_ANTI_DRIFTING = "inverted_anti_drifting"


class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using HunyuanVideo.
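Reviewer note: a minimal standalone sketch of what the new `FramepackInferenceType` enum exposes (only the standard library is assumed; the real class lives in this file):

```python
from enum import Enum


class FramepackInferenceType(str, Enum):
    VANILLA = "vanilla"
    INVERTED_ANTI_DRIFTING = "inverted_anti_drifting"


# Because the enum subclasses `str`, members compare equal to their raw string
# values, so callers may pass either form to the pipeline.
assert FramepackInferenceType.VANILLA == "vanilla"
assert [x.value for x in FramepackInferenceType.__members__.values()] == [
    "vanilla",
    "inverted_anti_drifting",
]
```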
@@ -455,6 +461,11 @@ def check_inputs(
        prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        prompt_template=None,
        image=None,
        image_latents=None,
        last_image=None,
        last_image_latents=None,
        inference_type=None,
    ):
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -493,6 +504,21 @@ def check_inputs(
                    f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
                )

        inference_types = [x.value for x in FramepackInferenceType.__members__.values()]
        if inference_type not in inference_types:
            raise ValueError(f"`inference_type` has to be one of '{inference_types}' but is '{inference_type}'")

        if image is not None and image_latents is not None:
            raise ValueError("Only one of `image` or `image_latents` can be passed.")
        if last_image is not None and last_image_latents is not None:
            raise ValueError("Only one of `last_image` or `last_image_latents` can be passed.")
        if inference_type != FramepackInferenceType.INVERTED_ANTI_DRIFTING and (
            last_image is not None or last_image_latents is not None
        ):
            raise ValueError(
                'Only `"inverted_anti_drifting"` inference type supports `last_image` or `last_image_latents`.'
            )

    def prepare_latents(
        self,
        batch_size: int = 1,
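Illustrative only: a self-contained sketch (not the pipeline's code) restating the argument rules the new `check_inputs` additions enforce, so the accepted and rejected combinations are easy to see:

```python
def validate(inference_type, image=None, image_latents=None, last_image=None, last_image_latents=None):
    # Simplified restatement of the checks added above.
    if inference_type not in ("vanilla", "inverted_anti_drifting"):
        raise ValueError(f"unknown inference_type: {inference_type!r}")
    if image is not None and image_latents is not None:
        raise ValueError("Only one of `image` or `image_latents` can be passed.")
    if last_image is not None and last_image_latents is not None:
        raise ValueError("Only one of `last_image` or `last_image_latents` can be passed.")
    if inference_type != "inverted_anti_drifting" and (last_image is not None or last_image_latents is not None):
        raise ValueError("Last-frame conditioning is only supported with inverted anti-drifting sampling.")


validate("inverted_anti_drifting", image="first.png", last_image="last.png")  # ok
validate("vanilla", image="first.png")  # ok
# validate("vanilla", image="first.png", last_image="last.png")  # raises: F1 samples forward in time
```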
@@ -623,6 +649,7 @@ def __call__(
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
        max_sequence_length: int = 256,
        inference_type: FramepackInferenceType = FramepackInferenceType.INVERTED_ANTI_DRIFTING,
    ):
        r"""
        The call function to the pipeline for generation.
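For context, a hedged usage sketch of the new `inference_type` argument. It assumes `pipe` is an already-constructed `HunyuanVideoFramepackPipeline` whose transformer carries F1 weights (e.g. converted from lllyasviel/FramePack_F1_I2V_HY_20250503) and `first_frame` is a PIL image; the loading code is outside this diff:

```python
import torch

# `pipe` and `first_frame` are assumed to exist (see note above). Only the new
# keyword matters here: F1 checkpoints sample forward in time ("vanilla"), while
# the original FramePack checkpoint keeps the default inverted anti-drifting.
output = pipe(
    image=first_frame,
    prompt="A graceful white swan glides across a misty lake at dawn",
    height=480,
    width=832,
    num_frames=91,
    num_inference_steps=30,
    generator=torch.Generator().manual_seed(0),
    inference_type="vanilla",
).frames[0]
```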
@@ -735,6 +762,11 @@ def __call__(
            prompt_embeds,
            callback_on_step_end_tensor_inputs,
            prompt_template,
            image,
            image_latents,
            last_image,
            last_image_latents,
            inference_type,
        )

        has_neg_prompt = negative_prompt is not None or (
@@ -806,18 +838,6 @@ def __call__(
        num_channels_latents = self.transformer.config.in_channels
        window_num_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1
        num_latent_sections = max(1, (num_frames + window_num_frames - 1) // window_num_frames)
        # Specific to the released checkpoint: https://huggingface.co/lllyasviel/FramePackI2V_HY
        # TODO: find a more generic way in future if there are more checkpoints
        history_sizes = [1, 2, 16]
        history_latents = torch.zeros(
            batch_size,
            num_channels_latents,
            sum(history_sizes),
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
            device=device,
            dtype=torch.float32,
        )
        history_video = None
        total_generated_latent_frames = 0

@@ -829,38 +849,92 @@ def __call__(
                last_image, dtype=torch.float32, device=device, generator=generator
            )

        latent_paddings = list(reversed(range(num_latent_sections)))
        if num_latent_sections > 4:
            latent_paddings = [3] + [2] * (num_latent_sections - 3) + [1, 0]
        # Specific to the released checkpoints:
        # - https://huggingface.co/lllyasviel/FramePackI2V_HY
        # - https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503
        # TODO: find a more generic way in future if there are more checkpoints
        if inference_type == FramepackInferenceType.INVERTED_ANTI_DRIFTING:
            history_sizes = [1, 2, 16]
            history_latents = torch.zeros(
                batch_size,
                num_channels_latents,
                sum(history_sizes),
                height // self.vae_scale_factor_spatial,
                width // self.vae_scale_factor_spatial,
                device=device,
                dtype=torch.float32,
            )

        elif inference_type == FramepackInferenceType.VANILLA:
            history_sizes = [16, 2, 1]
            history_latents = torch.zeros(
                batch_size,
                num_channels_latents,
                sum(history_sizes),
                height // self.vae_scale_factor_spatial,
                width // self.vae_scale_factor_spatial,
                device=device,
                dtype=torch.float32,
            )
            history_latents = torch.cat([history_latents, image_latents], dim=2)
            total_generated_latent_frames += 1

        else:
            assert False
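A small, self-contained sketch of the two history-buffer layouts initialized above, with placeholder latent dimensions (the real values come from the VAE and transformer configs):

```python
import torch

B, C, H, W = 1, 16, 60, 104  # placeholder latent dims, for illustration only

# Inverted anti-drifting (original FramePack): the buffer holds "future" context
# at sizes [1, 2, 16]; newly generated sections will be prepended to it.
history_inverted = torch.zeros(B, C, sum([1, 2, 16]), H, W)

# Vanilla (F1): the buffer holds "past" context at sizes [16, 2, 1] and is seeded
# with the first-frame latent, so generation appends forward in time.
image_latents = torch.zeros(B, C, 1, H, W)  # stands in for the encoded first frame
history_vanilla = torch.cat([torch.zeros(B, C, sum([16, 2, 1]), H, W), image_latents], dim=2)

print(history_inverted.shape)  # torch.Size([1, 16, 19, 60, 104])
print(history_vanilla.shape)   # torch.Size([1, 16, 20, 60, 104])
```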

        # 6. Prepare guidance condition
        guidance = torch.tensor([guidance_scale] * batch_size, dtype=transformer_dtype, device=device) * 1000.0
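This appears to follow the embedded (distilled) guidance convention used in the HunyuanVideo pipelines: the scale is multiplied by 1000 and fed to the transformer as conditioning rather than used for classifier-free guidance. A tiny sketch of the tensor being built:

```python
import torch

batch_size, guidance_scale = 2, 9.0
guidance = torch.tensor([guidance_scale] * batch_size, dtype=torch.float32) * 1000.0
print(guidance)  # tensor([9000., 9000.]) -- cast to the transformer dtype in the pipeline
```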

        # 7. Denoising loop
        for k in range(num_latent_sections):
            is_first_section = k == 0
            is_last_section = k == num_latent_sections - 1
            latent_padding_size = latent_paddings[k] * latent_window_size

            indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes]))
            (
                indices_prefix,
                indices_padding,
                indices_latents,
                indices_postfix,
                indices_latents_history_2x,
                indices_latents_history_4x,
            ) = indices.split([1, latent_padding_size, latent_window_size, *history_sizes], dim=0)
            # Inverted anti-drifting sampling: Figure 2(c) in the paper
            indices_clean_latents = torch.cat([indices_prefix, indices_postfix], dim=0)

            latents_prefix = image_latents
            latents_postfix, latents_history_2x, latents_history_4x = history_latents[
                :, :, : sum(history_sizes)
            ].split(history_sizes, dim=2)
            if last_image is not None and is_first_section:
                latents_postfix = last_image_latents
            latents_clean = torch.cat([latents_prefix, latents_postfix], dim=2)
            if inference_type == FramepackInferenceType.INVERTED_ANTI_DRIFTING:
                latent_paddings = list(reversed(range(num_latent_sections)))
                if num_latent_sections > 4:
                    latent_paddings = [3] + [2] * (num_latent_sections - 3) + [1, 0]

                is_first_section = k == 0
                is_last_section = k == num_latent_sections - 1
                latent_padding_size = latent_paddings[k] * latent_window_size

                indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes]))
                (
                    indices_prefix,
                    indices_padding,
                    indices_latents,
                    indices_latents_history_1x,
                    indices_latents_history_2x,
                    indices_latents_history_4x,
                ) = indices.split([1, latent_padding_size, latent_window_size, *history_sizes], dim=0)
                # Inverted anti-drifting sampling: Figure 2(c) in the paper
                indices_clean_latents = torch.cat([indices_prefix, indices_latents_history_1x], dim=0)

                latents_prefix = image_latents
                latents_history_1x, latents_history_2x, latents_history_4x = history_latents[
                    :, :, : sum(history_sizes)
                ].split(history_sizes, dim=2)
                if last_image is not None and is_first_section:
                    latents_history_1x = last_image_latents
                latents_clean = torch.cat([latents_prefix, latents_history_1x], dim=2)

            elif inference_type == FramepackInferenceType.VANILLA:
                indices = torch.arange(0, sum([1, *history_sizes, latent_window_size]))
                (
                    indices_prefix,
                    indices_latents_history_4x,
                    indices_latents_history_2x,
                    indices_latents_history_1x,
                    indices_latents,
                ) = indices.split([1, *history_sizes, latent_window_size], dim=0)
                indices_clean_latents = torch.cat([indices_prefix, indices_latents_history_1x], dim=0)

                latents_prefix = image_latents
                latents_history_4x, latents_history_2x, latents_history_1x = history_latents[
                    :, :, -sum(history_sizes) :
                ].split(history_sizes, dim=2)
                latents_clean = torch.cat([latents_prefix, latents_history_1x], dim=2)

            else:
                assert False
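To see what the two `indices.split(...)` layouts above produce, here is a toy example with a small `latent_window_size` and the checkpoint-specific history sizes (values chosen only for illustration):

```python
import torch

latent_window_size = 3

# Inverted anti-drifting: [prefix, padding, window, 1x, 2x, 4x] with history_sizes = [1, 2, 16]
history_sizes = [1, 2, 16]
latent_padding_size = 2 * latent_window_size  # e.g. latent_paddings[k] == 2
sizes = [1, latent_padding_size, latent_window_size, *history_sizes]
indices_prefix, indices_padding, indices_latents, indices_1x, indices_2x, indices_4x = torch.arange(
    0, sum(sizes)
).split(sizes, dim=0)
print(indices_latents)  # tensor([7, 8, 9]) -> the denoised window sits before the "future" history

# Vanilla (F1): [prefix, 4x, 2x, 1x, window] with history_sizes = [16, 2, 1]
history_sizes = [16, 2, 1]
sizes = [1, *history_sizes, latent_window_size]
indices_prefix, indices_4x, indices_2x, indices_1x, indices_latents = torch.arange(0, sum(sizes)).split(sizes, dim=0)
print(indices_latents)  # tensor([20, 21, 22]) -> the denoised window comes after the "past" history
```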

            latents = self.prepare_latents(
                batch_size,
@@ -964,9 +1038,24 @@ def __call__(
                latents = torch.cat([image_latents, latents], dim=2)

            total_generated_latent_frames += latents.shape[2]
            history_latents = torch.cat([latents, history_latents], dim=2)
            overlapped_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1

            if inference_type == FramepackInferenceType.INVERTED_ANTI_DRIFTING:
                history_latents = torch.cat([latents, history_latents], dim=2)
                real_history_latents = history_latents[:, :, :total_generated_latent_frames]
                section_latent_frames = (
                    (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
                )
                index_slice = (slice(None), slice(None), slice(0, section_latent_frames))

            elif inference_type == FramepackInferenceType.VANILLA:
                history_latents = torch.cat([history_latents, latents], dim=2)
                real_history_latents = history_latents[:, :, -total_generated_latent_frames:]
                section_latent_frames = latent_window_size * 2
                index_slice = (slice(None), slice(None), slice(-section_latent_frames, None))

            real_history_latents = history_latents[:, :, :total_generated_latent_frames]
            else:
                assert False
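The `index_slice` tuples above simply pick the newest `section_latent_frames` from the correct end of `real_history_latents`; a minimal sketch (1-D stand-in for the frame axis, whereas the real tuple carries three slices for batch, channel, and frame):

```python
import torch

frames = torch.arange(10)  # stand-in for the frame axis of real_history_latents
section_latent_frames = 4

# Inverted anti-drifting prepends new latents, so the newest section is at the front.
front = (slice(0, section_latent_frames),)
print(frames[front])  # tensor([0, 1, 2, 3])

# Vanilla appends new latents, so the newest section is at the back.
back = (slice(-section_latent_frames, None),)
print(frames[back])  # tensor([6, 7, 8, 9])
```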

            if history_video is None:
                if not output_type == "latent":
@@ -976,16 +1065,17 @@
                    history_video = [real_history_latents]
            else:
                if not output_type == "latent":
                    section_latent_frames = (
                        (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
                    )
                    overlapped_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1
                    current_latents = (
                        real_history_latents[:, :, :section_latent_frames].to(vae_dtype)
                        / self.vae.config.scaling_factor
                        real_history_latents[index_slice].to(vae_dtype) / self.vae.config.scaling_factor
                    )
                    current_video = self.vae.decode(current_latents, return_dict=False)[0]
                    history_video = self._soft_append(current_video, history_video, overlapped_frames)

                    if inference_type == FramepackInferenceType.INVERTED_ANTI_DRIFTING:
                        history_video = self._soft_append(current_video, history_video, overlapped_frames)
                    elif inference_type == FramepackInferenceType.VANILLA:
                        history_video = self._soft_append(history_video, current_video, overlapped_frames)
                    else:
                        assert False
                else:
                    history_video.append(real_history_latents)
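The only change to the video stitching is the argument order: inverted anti-drifting blends the newly decoded (earlier-in-time) section in front of the accumulated history, while vanilla blends it after. A hypothetical linear crossfade in the same spirit as `_soft_append` (illustrative only, not the pipeline's actual implementation):

```python
import torch


def soft_append_sketch(leading: torch.Tensor, trailing: torch.Tensor, overlap: int) -> torch.Tensor:
    # Blend `overlap` frames (dim=2 is time) with a linear ramp, then concatenate the rest.
    weights = torch.linspace(1.0, 0.0, overlap).view(1, 1, overlap, 1, 1)
    blended = weights * leading[:, :, -overlap:] + (1.0 - weights) * trailing[:, :, :overlap]
    return torch.cat([leading[:, :, :-overlap], blended, trailing[:, :, overlap:]], dim=2)


leading = torch.zeros(1, 3, 8, 4, 4)
trailing = torch.ones(1, 3, 8, 4, 4)
print(soft_append_sketch(leading, trailing, overlap=3).shape)  # torch.Size([1, 3, 13, 4, 4])
```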
