
add SV4D 2.0 #440


Merged: 3 commits, May 20, 2025
3 changes: 2 additions & 1 deletion .gitignore
@@ -12,4 +12,5 @@
/outputs
/build
/src
/.vscode
/.vscode
**/__pycache__/
41 changes: 41 additions & 0 deletions README.md
@@ -5,6 +5,46 @@
## News


**April 4, 2025**
- We are releasing **[Stable Video 4D 2.0 (SV4D 2.0)](https://huggingface.co/stabilityai/sv4d2.0)**, an enhanced video-to-4D diffusion model for high-fidelity novel-view video synthesis and 4D asset generation. For research purposes:
- **SV4D 2.0** was trained to generate 48 frames (12 video frames x 4 camera views) at 576x576 resolution, given a 12-frame input video of the same size, ideally consisting of white-background images of a moving object.
- Compared to our previous 4D model [SV4D](https://huggingface.co/stabilityai/sv4d), **SV4D 2.0** can generate videos with higher fidelity, sharper details during motion, and better spatio-temporal consistency. It also generalizes much better to real-world videos. Moreover, it does not rely on reference multi-view images of the first frame generated by SV3D, making it more robust to self-occlusions.
- To generate longer novel-view videos, we autoregressively generate 12 frames at a time and use the previous generation as conditioning views for the remaining frames (see the sketch after this list).
- Please check our [project page](https://sv4d20.github.io), [arxiv paper](https://arxiv.org/pdf/2503.16396) and [video summary](https://www.youtube.com/watch?v=dtqj-s50ynU) for more details.
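
To make the autoregressive scheme concrete, here is a minimal sketch. It is an illustration only, not the repository's sampling code: `generate_window` is a hypothetical stand-in for one SV4D 2.0 diffusion call, and the window schedule follows the 12-frame model and 21-frame videos described above.

```
# Hedged sketch of the sliding-window, autoregressive generation described above.
# `generate_window` is a hypothetical placeholder for one SV4D 2.0 sampling call;
# the real entry point is scripts/sampling/simple_video_sample_4d2.py.
import numpy as np

T = 12          # frames generated per window (12-frame model)
n_frames = 21   # total number of input/output video frames

def generate_window(frame_indices, anchor_views):
    """Placeholder: returns novel views for `frame_indices`, conditioned on the
    previously generated views (`anchor_views`) at the window's first frame."""
    return {int(t): f"novel_views@{t}" for t in frame_indices}

novel_views = {}
for t0 in range(0, n_frames, T):
    t0 = min(t0, n_frames - T)        # clamp the last window so it ends exactly at frame 21
    frame_indices = t0 + np.arange(T)
    anchor = novel_views.get(t0)      # None for the first window, previous output afterwards
    novel_views.update(generate_window(frame_indices, anchor))
```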

**QUICKSTART** :
- `python scripts/sampling/simple_video_sample_4d2.py --input_path assets/sv4d_videos/camel.gif --output_folder outputs` (after downloading [sv4d2.safetensors](https://huggingface.co/stabilityai/sv4d2.0) from HuggingFace into `checkpoints/`)

To run **SV4D 2.0** on a single input video of 21 frames:
- Download SV4D 2.0 model (`sv4d2.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d2.0) to `checkpoints/`: `huggingface-cli download stabilityai/sv4d2.0 sv4d2.safetensors --local-dir checkpoints`
- Run inference: `python scripts/sampling/simple_video_sample_4d2.py --input_path <path/to/video>`
- `input_path` : The input video `<path/to/video>` can be
- a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/camel.gif`, or
- a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or
- a file name pattern matching images of video frames.
- `num_steps` : default is 50; decrease it to shorten sampling time.
- `elevations_deg` : specified elevations (relative to the input view); default is 0.0 (same as the input view).
- **Background removal** : For input videos with a plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove the background and crop video frames by setting `--remove_bg=True`. To obtain higher-quality outputs on real-world input videos with noisy backgrounds, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D (a manual preprocessing sketch follows this list).
- **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--encoding_t=1` (number of frames encoded at a time) and `--decoding_t=1` (number of frames decoded at a time), or use a lower video resolution such as `--img_size=512`.
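
If you prefer to do the preprocessing yourself, the sketch below removes the background frame by frame with `rembg` and composites onto white before the frames are handed to SV4D 2.0. It is illustrative only (it assumes `rembg` and `Pillow` are installed and that a folder of `.png` frames will be used as `--input_path`); the built-in `--remove_bg=True` path covers the common case.

```
# Hedged preprocessing sketch: strip the background from a GIF's frames with rembg
# and save white-background PNG frames that can be passed as --input_path.
from pathlib import Path
from PIL import Image, ImageSequence
from rembg import remove  # pip install rembg

def preprocess_gif(gif_path: str, out_dir: str = "preprocessed_frames") -> None:
    Path(out_dir).mkdir(exist_ok=True)
    with Image.open(gif_path) as gif:
        for i, frame in enumerate(ImageSequence.Iterator(gif)):
            fg = remove(frame.convert("RGB")).convert("RGBA")   # foreground with alpha matte
            white = Image.new("RGBA", fg.size, (255, 255, 255, 255))
            Image.alpha_composite(white, fg).convert("RGB").save(Path(out_dir) / f"{i:06d}.png")

preprocess_gif("assets/sv4d_videos/camel.gif")
```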

Notes:
- We also trained an 8-view model that generates 5 frames x 8 views at a time (same as SV4D).
- Download the model from huggingface: `huggingface-cli download stabilityai/sv4d2.0 sv4d2_8views.safetensors --local-dir checkpoints`
- Run inference: `python scripts/sampling/simple_video_sample_4d2.py --model_path checkpoints/sv4d2_8views.safetensors --input_path assets/sv4d_videos/chest.gif --output_folder outputs`
- The 5x8 model takes 5 input frames at a time, but the inference scripts for both models take a 21-frame video as input by default (same as SV3D and SV4D); we run the model autoregressively until all 21 frames are generated. A programmatic way to invoke either model is sketched after the install commands below.
- Install dependencies before running:
```
python3.10 -m venv .generativemodels
source .generativemodels/bin/activate
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # check CUDA version
pip3 install -r requirements/pt2.txt
pip3 install .
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
```
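
The CLI commands above can also be reproduced from Python, since the new sampling script exposes a `sample()` function (see `scripts/sampling/simple_video_sample_4d2.py` in this PR). A minimal sketch, assuming the repo root is on `PYTHONPATH` and the checkpoint has been downloaded; the argument values are illustrative:

```
# Hedged sketch: call the sampling script programmatically instead of via the CLI.
# Argument names match the sample() signature added in this PR.
from scripts.sampling.simple_video_sample_4d2 import sample

sample(
    input_path="assets/sv4d_videos/chest.gif",
    model_path="checkpoints/sv4d2_8views.safetensors",  # or checkpoints/sv4d2.safetensors
    output_folder="outputs",
    num_steps=50,
    encoding_t=1,   # lower these two on low-VRAM GPUs
    decoding_t=1,
    remove_bg=False,
)
```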

![tile](assets/sv4d2.gif)


**July 24, 2024**
- We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis. For research purposes:
- **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video), and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object.
@@ -164,6 +204,7 @@ This is assuming you have navigated to the `generative-models` root after clonin
# install required packages from pypi
python3 -m venv .pt2
source .pt2/bin/activate
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip3 install -r requirements/pt2.txt
```

Binary file added assets/sv4d2.gif
Binary file added assets/sv4d_videos/bear.gif
Binary file added assets/sv4d_videos/bee.gif
Binary file added assets/sv4d_videos/bmx-bumps.gif
Binary file removed assets/sv4d_videos/bunnyman.mp4
Binary file added assets/sv4d_videos/camel.gif
Binary file added assets/sv4d_videos/chameleon.gif
Binary file added assets/sv4d_videos/chest.gif
Binary file added assets/sv4d_videos/cows.gif
Binary file added assets/sv4d_videos/dance-twirl.gif
Binary file removed assets/sv4d_videos/dolphin.mp4
Binary file added assets/sv4d_videos/flag.gif
Binary file added assets/sv4d_videos/gear.gif
Binary file removed assets/sv4d_videos/green_robot.mp4
Binary file removed assets/sv4d_videos/guppie_v0.mp4
Binary file added assets/sv4d_videos/hike.gif
Binary file removed assets/sv4d_videos/hiphop_parrot.mp4
Binary file added assets/sv4d_videos/horsejump-low.gif
Binary file removed assets/sv4d_videos/human5.mp4
Binary file removed assets/sv4d_videos/human7.mp4
Binary file removed assets/sv4d_videos/lucia_v000.mp4
Binary file removed assets/sv4d_videos/monkey.mp4
Binary file removed assets/sv4d_videos/pistol_v0.mp4
Binary file added assets/sv4d_videos/robot.gif
Binary file added assets/sv4d_videos/snowboard.gif
Binary file removed assets/sv4d_videos/snowboard_v000.mp4
Binary file removed assets/sv4d_videos/stroller_v000.mp4
Binary file removed assets/sv4d_videos/test_video2.mp4
Binary file removed assets/sv4d_videos/train_v0.mp4
Binary file removed assets/sv4d_videos/wave_hello.mp4
Binary file added assets/sv4d_videos/windmill.gif
5 changes: 4 additions & 1 deletion requirements/pt2.txt
@@ -5,13 +5,16 @@ einops>=0.6.1
fairscale>=0.4.13
fire>=0.5.0
fsspec>=2023.6.0
imageio[ffmpeg]
imageio[pyav]
invisible-watermark>=0.2.0
kornia==0.6.9
matplotlib>=3.7.2
natsort>=8.4.0
ninja>=1.11.1
numpy>=1.24.4
numpy==2.1
omegaconf>=2.3.0
onnxruntime
open-clip-torch>=2.20.0
opencv-python==4.6.0.66
pandas>=2.0.3
250 changes: 187 additions & 63 deletions scripts/demo/sv4d_helpers.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions scripts/sampling/configs/sv4d.yaml
@@ -23,6 +23,7 @@ model:
attention_resolutions: [4, 2, 1]
channel_mult: [1, 2, 4, 4]
context_dim: 1024
motion_context_dim: 4
extra_ff_mix_layer: True
in_channels: 8
legacy: False
208 changes: 208 additions & 0 deletions scripts/sampling/configs/sv4d2.yaml
@@ -0,0 +1,208 @@
N_TIME: 12
N_VIEW: 4
N_FRAMES: 48

model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
en_and_decode_n_samples_a_time: 8
disable_first_stage_autocast: True
ckpt_path: checkpoints/sv4d2.safetensors
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

network_config:
target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime
params:
adm_in_channels: 1280
attention_resolutions: [4, 2, 1]
channel_mult: [1, 2, 4, 4]
context_dim: 1024
motion_context_dim: 4
extra_ff_mix_layer: True
in_channels: 8
legacy: False
model_channels: 320
num_classes: sequential
num_head_channels: 64
num_res_blocks: 2
out_channels: 4
replicate_time_mix_bug: True
spatial_transformer_attn_type: softmax-xformers
time_block_merge_factor: 0.0
time_block_merge_strategy: learned_with_images
time_kernel_size: [3, 1, 1]
time_mix_legacy: False
transformer_depth: 1
use_checkpoint: False
use_linear_in_transformer: True
use_spatial_context: True
use_spatial_transformer: True
separate_motion_merge_factor: True
use_motion_attention: True
use_3d_attention: True
use_camera_emb: True

conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:

- input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
is_trainable: False
params:
n_cond_frames: ${N_TIME}
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True

- input_key: cond_frames
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
is_trainable: False
params:
is_ae: True
n_cond_frames: ${N_FRAMES}
n_copies: 1
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
ddconfig:
attn_resolutions: []
attn_type: vanilla-xformers
ch: 128
ch_mult: [1, 2, 4, 4]
double_z: True
dropout: 0.0
in_channels: 3
num_res_blocks: 2
out_ch: 3
resolution: 256
z_channels: 4
embed_dim: 4
lossconfig:
target: torch.nn.Identity
monitor: val/rec_loss
sigma_cond_config:
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

- input_key: polar_rad
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 512

- input_key: azimuth_rad
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 512

- input_key: cond_view
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
is_ae: True
n_cond_frames: ${N_VIEW}
n_copies: 1
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_resolutions: []
attn_type: vanilla-xformers
ch: 128
ch_mult: [1, 2, 4, 4]
double_z: True
dropout: 0.0
in_channels: 3
num_res_blocks: 2
out_ch: 3
resolution: 256
z_channels: 4
lossconfig:
target: torch.nn.Identity
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

- input_key: cond_motion
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
is_ae: True
n_cond_frames: ${N_TIME}
n_copies: 1
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_resolutions: []
attn_type: vanilla-xformers
ch: 128
ch_mult: [1, 2, 4, 4]
double_z: True
dropout: 0.0
in_channels: 3
num_res_blocks: 2
out_ch: 3
resolution: 256
z_channels: 4
lossconfig:
target: torch.nn.Identity
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: torch.nn.Identity
decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params:
attn_resolutions: []
attn_type: vanilla-xformers
ch: 128
ch_mult: [1, 2, 4, 4]
double_z: True
dropout: 0.0
in_channels: 3
num_res_blocks: 2
out_ch: 3
resolution: 256
z_channels: 4

sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
num_steps: 50
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
params:
sigma_max: 500.0
guider_config:
target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider
params:
max_scale: 1.5
min_scale: 1.5
num_frames: ${N_FRAMES}
num_views: ${N_VIEW}
additional_cond_keys: [ cond_view, cond_motion ]
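
A side note on the config above: `${N_TIME}`, `${N_VIEW}`, and `${N_FRAMES}` are OmegaConf-style interpolations of the top-level keys (omegaconf is listed in `requirements/pt2.txt`). The snippet below is only a sketch of how they resolve when the file is loaded with OmegaConf, independent of the repo's own loading code:

```
# Hedged sketch: the ${...} placeholders in sv4d2.yaml resolve against the top-level keys.
from omegaconf import OmegaConf

cfg = OmegaConf.load("scripts/sampling/configs/sv4d2.yaml")
guider = cfg.model.params.sampler_config.params.guider_config.params
print(guider.num_frames, guider.num_views)  # -> 48 4 (N_FRAMES and N_VIEW)
```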
208 changes: 208 additions & 0 deletions scripts/sampling/configs/sv4d2_8views.yaml
@@ -0,0 +1,208 @@
N_TIME: 5
N_VIEW: 8
N_FRAMES: 40

model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
en_and_decode_n_samples_a_time: 8
disable_first_stage_autocast: True
ckpt_path: checkpoints/sv4d2_8views.safetensors
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

network_config:
target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime
params:
adm_in_channels: 1280
attention_resolutions: [4, 2, 1]
channel_mult: [1, 2, 4, 4]
context_dim: 1024
motion_context_dim: 4
extra_ff_mix_layer: True
in_channels: 8
legacy: False
model_channels: 320
num_classes: sequential
num_head_channels: 64
num_res_blocks: 2
out_channels: 4
replicate_time_mix_bug: True
spatial_transformer_attn_type: softmax-xformers
time_block_merge_factor: 0.0
time_block_merge_strategy: learned_with_images
time_kernel_size: [3, 1, 1]
time_mix_legacy: False
transformer_depth: 1
use_checkpoint: False
use_linear_in_transformer: True
use_spatial_context: True
use_spatial_transformer: True
separate_motion_merge_factor: True
use_motion_attention: True
use_3d_attention: False
use_camera_emb: True

conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:

- input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
is_trainable: False
params:
n_cond_frames: ${N_TIME}
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True

- input_key: cond_frames
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
is_trainable: False
params:
is_ae: True
n_cond_frames: ${N_FRAMES}
n_copies: 1
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
ddconfig:
attn_resolutions: []
attn_type: vanilla-xformers
ch: 128
ch_mult: [1, 2, 4, 4]
double_z: True
dropout: 0.0
in_channels: 3
num_res_blocks: 2
out_ch: 3
resolution: 256
z_channels: 4
embed_dim: 4
lossconfig:
target: torch.nn.Identity
monitor: val/rec_loss
sigma_cond_config:
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

- input_key: polar_rad
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 512

- input_key: azimuth_rad
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 512

- input_key: cond_view
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
is_ae: True
n_cond_frames: ${N_VIEW}
n_copies: 1
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_resolutions: []
attn_type: vanilla-xformers
ch: 128
ch_mult: [1, 2, 4, 4]
double_z: True
dropout: 0.0
in_channels: 3
num_res_blocks: 2
out_ch: 3
resolution: 256
z_channels: 4
lossconfig:
target: torch.nn.Identity
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

- input_key: cond_motion
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
is_ae: True
n_cond_frames: ${N_TIME}
n_copies: 1
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_resolutions: []
attn_type: vanilla-xformers
ch: 128
ch_mult: [1, 2, 4, 4]
double_z: True
dropout: 0.0
in_channels: 3
num_res_blocks: 2
out_ch: 3
resolution: 256
z_channels: 4
lossconfig:
target: torch.nn.Identity
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: torch.nn.Identity
decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params:
attn_resolutions: []
attn_type: vanilla-xformers
ch: 128
ch_mult: [1, 2, 4, 4]
double_z: True
dropout: 0.0
in_channels: 3
num_res_blocks: 2
out_ch: 3
resolution: 256
z_channels: 4

sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
num_steps: 50
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
params:
sigma_max: 500.0
guider_config:
target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider
params:
max_scale: 2.0
min_scale: 1.5
num_frames: ${N_FRAMES}
num_views: ${N_VIEW}
additional_cond_keys: [ cond_view, cond_motion ]
5 changes: 3 additions & 2 deletions scripts/sampling/simple_video_sample.py
@@ -163,7 +163,7 @@ def sample(
else:
with Image.open(input_img_path) as image:
if image.mode == "RGBA":
input_image = image.convert("RGB")
image = image.convert("RGB")
w, h = image.size

if h % 64 != 0 or w % 64 != 0:
@@ -172,7 +172,8 @@ def sample(
print(
f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
)

input_image = np.array(image)

image = ToTensor()(input_image)
image = image * 2.0 - 1.0

16 changes: 9 additions & 7 deletions scripts/sampling/simple_video_sample_4d.py
@@ -27,7 +27,7 @@


def sample(
input_path: str = "assets/test_video.mp4", # Can either be image file or folder with image files
input_path: str = "assets/sv4d_videos/test_video1.mp4", # Can either be image file or folder with image files
output_folder: Optional[str] = "outputs/sv4d",
num_steps: Optional[int] = 20,
sv3d_version: str = "sv3d_u", # sv3d_u or sv3d_p
@@ -71,7 +71,8 @@ def sample(
"f": F,
"options": {
"discretization": 1,
"cfg": 3.0,
"cfg": 2.0,
"num_views": V,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
@@ -94,6 +95,7 @@ def sample(

# Read input video frames i.e. images at view 0
print(f"Reading {input_path}")
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11
processed_input_path = preprocess_video(
input_path,
remove_bg=remove_bg,
@@ -102,6 +104,7 @@ def sample(
H=H,
output_folder=output_folder,
image_frame_ratio=image_frame_ratio,
base_count=base_count,
)
images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)

@@ -145,15 +148,14 @@ def sample(
for t in range(n_frames):
img_matrix[t][0] = images_v0[t]

base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 12
save_video(
os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
img_matrix[0],
)
save_video(
os.path.join(output_folder, f"{base_count:06d}_v000.mp4"),
[img_matrix[t][0] for t in range(n_frames)],
)
# save_video(
# os.path.join(output_folder, f"{base_count:06d}_v000.mp4"),
# [img_matrix[t][0] for t in range(n_frames)],
# )

# Load SV4D model
model, filter = load_model(
235 changes: 235 additions & 0 deletions scripts/sampling/simple_video_sample_4d2.py
@@ -0,0 +1,235 @@
import os
import sys
from glob import glob
from typing import List, Optional

from tqdm import tqdm

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
import numpy as np
import torch
from fire import Fire
from scripts.demo.sv4d_helpers import (
load_model,
preprocess_video,
read_video,
run_img2vid,
save_video,
)
from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder

sv4d2_configs = {
"sv4d2": {
"T": 12, # number of frames per sample
"V": 4, # number of views per sample
"model_config": "scripts/sampling/configs/sv4d2.yaml",
"version_dict": {
"T": 12 * 4,
"options": {
"discretization": 1,
"cfg": 2.0,
"min_cfg": 2.0,
"num_views": 4,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 2,
"force_uc_zero_embeddings": [
"cond_frames",
"cond_frames_without_noise",
"cond_view",
"cond_motion",
],
"additional_guider_kwargs": {
"additional_cond_keys": ["cond_view", "cond_motion"]
},
},
},
},
"sv4d2_8views": {
"T": 5, # number of frames per sample
"V": 8, # number of views per sample
"model_config": "scripts/sampling/configs/sv4d2_8views.yaml",
"version_dict": {
"T": 5 * 8,
"options": {
"discretization": 1,
"cfg": 2.5,
"min_cfg": 1.5,
"num_views": 8,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 5,
"force_uc_zero_embeddings": [
"cond_frames",
"cond_frames_without_noise",
"cond_view",
"cond_motion",
],
"additional_guider_kwargs": {
"additional_cond_keys": ["cond_view", "cond_motion"]
},
},
},
},
}


def sample(
input_path: str = "assets/sv4d_videos/camel.gif", # Can either be image file or folder with image files
model_path: Optional[str] = "checkpoints/sv4d2.safetensors",
output_folder: Optional[str] = "outputs",
num_steps: Optional[int] = 50,
img_size: int = 576, # image resolution
n_frames: int = 21, # number of input and output video frames
seed: int = 23,
encoding_t: int = 8, # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
device: str = "cuda",
elevations_deg: Optional[List[float]] = 0.0,
azimuths_deg: Optional[List[float]] = None,
image_frame_ratio: Optional[float] = 0.9,
verbose: Optional[bool] = False,
remove_bg: bool = False,
):
"""
Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.
"""
# Set model config
assert os.path.basename(model_path) in [
"sv4d2.safetensors",
"sv4d2_8views.safetensors",
]
sv4d2_model = os.path.splitext(os.path.basename(model_path))[0]
config = sv4d2_configs[sv4d2_model]
print(sv4d2_model, config)
T = config["T"]
V = config["V"]
model_config = config["model_config"]
version_dict = config["version_dict"]
F = 8 # vae factor to downsize image->latent
C = 4
H, W = img_size, img_size
n_views = V + 1 # number of output video views (1 input view + V novel views)
subsampled_views = np.arange(n_views)
version_dict["H"] = H
version_dict["W"] = W
version_dict["C"] = C
version_dict["f"] = F
version_dict["options"]["num_steps"] = num_steps

torch.manual_seed(seed)
output_folder = os.path.join(output_folder, sv4d2_model)
os.makedirs(output_folder, exist_ok=True)

# Read input video frames i.e. images at view 0
print(f"Reading {input_path}")
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // n_views
processed_input_path = preprocess_video(
input_path,
remove_bg=remove_bg,
n_frames=n_frames,
W=W,
H=H,
output_folder=output_folder,
image_frame_ratio=image_frame_ratio,
base_count=base_count,
)
images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)
images_t0 = torch.zeros(n_views, 3, H, W).float().to(device)

# Get camera viewpoints
if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
elevations_deg = [elevations_deg] * n_views
assert (
len(elevations_deg) == n_views
), f"Please provide 1 value, or a list of {n_views} values for elevations_deg! Given {len(elevations_deg)}"
if azimuths_deg is None:
# azimuths_deg = np.linspace(0, 360, n_views + 1)[1:] % 360
azimuths_deg = (
np.array([0, 60, 120, 180, 240])
if sv4d2_model == "sv4d2"
else np.array([0, 30, 75, 120, 165, 210, 255, 300, 330])
)
assert (
len(azimuths_deg) == n_views
), f"Please provide a list of {n_views} values for azimuths_deg! Given {len(azimuths_deg)}"
polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
azimuths_rad = np.array(
[np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
)

# Initialize image matrix
img_matrix = [[None] * n_views for _ in range(n_frames)]
for i, v in enumerate(subsampled_views):
img_matrix[0][i] = images_t0[v].unsqueeze(0)
for t in range(n_frames):
img_matrix[t][0] = images_v0[t]

# Load SV4D++ model
model, _ = load_model(
model_config,
device,
version_dict["T"],
num_steps,
verbose,
model_path,
)
model.en_and_decode_n_samples_a_time = decoding_t
for emb in model.conditioner.embedders:
if isinstance(emb, VideoPredictionEmbedderWithEncoder):
emb.en_and_decode_n_samples_a_time = encoding_t

# Sampling novel-view videos
v0 = 0
view_indices = np.arange(V) + 1
t0_list = (
range(0, n_frames, T)
if sv4d2_model == "sv4d2"
else range(0, n_frames - T + 1, T - 1)
)
for t0 in tqdm(t0_list):
if t0 + T > n_frames:
t0 = n_frames - T
frame_indices = t0 + np.arange(T)
print(f"Sampling frames {frame_indices}")
image = img_matrix[t0][v0]
cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
polars = (polars - polars_rad[v0] + torch.pi / 2) % (torch.pi * 2)
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
cond_mv = False if t0 == 0 else True
samples = run_img2vid(
version_dict,
model,
image,
seed,
polars,
azims,
cond_motion,
cond_view,
decoding_t,
cond_mv=cond_mv,
)
samples = samples.view(T, V, 3, H, W)

for i, t in enumerate(frame_indices):
for j, v in enumerate(view_indices):
img_matrix[t][v] = samples[i, j][None] * 2 - 1

# Save output videos
for v in view_indices:
vid_file = os.path.join(output_folder, f"{base_count:06d}_v{v:03d}.mp4")
print(f"Saving {vid_file}")
save_video(
vid_file,
[img_matrix[t][v] for t in range(n_frames) if img_matrix[t][v] is not None],
)


if __name__ == "__main__":
Fire(sample)
24 changes: 21 additions & 3 deletions sgm/modules/diffusionmodules/openaimodel.py
@@ -74,6 +74,7 @@ def forward(
x: th.Tensor,
emb: th.Tensor,
context: Optional[th.Tensor] = None,
cam: Optional[th.Tensor] = None,
image_only_indicator: Optional[th.Tensor] = None,
cond_view: Optional[th.Tensor] = None,
cond_motion: Optional[th.Tensor] = None,
@@ -86,7 +87,7 @@ def forward(
from ...modules.spacetime_attention import (
BasicTransformerTimeMixBlock,
PostHocSpatialTransformerWithTimeMixing,
PostHocSpatialTransformerWithTimeMixingAndMotion
PostHocSpatialTransformerWithTimeMixingAndMotion,
)

for layer in self:
@@ -97,13 +98,30 @@ def forward(
(
BasicTransformerTimeMixBlock,
PostHocSpatialTransformerWithTimeMixing,
PostHocSpatialTransformerWithTimeMixingAndMotion
),
):
x = layer(
x,
context,
# cam,
emb,
time_context,
num_video_frames,
image_only_indicator,
cond_view,
cond_motion,
time_step,
name,
)
elif isinstance(
module,
(
PostHocSpatialTransformerWithTimeMixingAndMotion,
),
):
x = layer(
x,
context,
emb,
time_context,
num_video_frames,
image_only_indicator,
31 changes: 22 additions & 9 deletions sgm/modules/diffusionmodules/video_model.py
@@ -8,10 +8,10 @@
from ...modules.spacetime_attention import (
BasicTransformerTimeMixBlock,
PostHocSpatialTransformerWithTimeMixing,
PostHocSpatialTransformerWithTimeMixingAndMotion
PostHocSpatialTransformerWithTimeMixingAndMotion,
)
from ...util import default
from .util import AlphaBlender # , LegacyAlphaBlenderWithBug, get_alpha
from .util import AlphaBlender, get_alpha


class VideoResBlock(ResBlock):
@@ -716,11 +716,11 @@ def forward(
)

if self.time_mix_legacy:
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator*0.0)
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
else:
x = self.time_mixer(
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator*0.0
)
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
@@ -752,10 +752,14 @@ def __init__(
context_dim: Optional[int] = None,
time_downup: bool = False,
time_context_dim: Optional[int] = None,
view_context_dim: Optional[int] = None,
motion_context_dim: Optional[int] = None,
extra_ff_mix_layer: bool = False,
use_spatial_context: bool = False,
time_block_merge_strategy: str = "fixed",
time_block_merge_factor: float = 0.5,
view_block_merge_factor: float = 0.5,
motion_block_merge_factor: float = 0.5,
spatial_transformer_attn_type: str = "softmax",
time_kernel_size: Union[int, List[int]] = 3,
use_linear_in_transformer: bool = False,
@@ -767,6 +771,9 @@ def __init__(
max_ddpm_temb_period: int = 10000,
replicate_time_mix_bug: bool = False,
use_motion_attention: bool = False,
use_camera_emb: bool = False,
use_3d_attention: bool = False,
separate_motion_merge_factor: bool = False,
):
super().__init__()

@@ -886,11 +893,17 @@ def get_attention_layer(
depth=depth,
context_dim=context_dim,
time_context_dim=time_context_dim,
motion_context_dim=motion_context_dim,
dropout=dropout,
ff_in=extra_ff_mix_layer,
use_spatial_context=use_spatial_context,
use_camera_emb=use_camera_emb,
use_3d_attention=use_3d_attention,
separate_motion_merge_factor=separate_motion_merge_factor,
adm_in_channels=adm_in_channels,
merge_strategy=time_block_merge_strategy,
merge_factor=time_block_merge_factor,
merge_factor_motion=motion_block_merge_factor,
checkpoint=use_checkpoint,
use_linear=use_linear_in_transformer,
attn_mode=spatial_transformer_attn_type,
@@ -899,7 +912,7 @@ def get_attention_layer(
time_mix_legacy=time_mix_legacy,
max_time_embed_period=max_ddpm_temb_period,
)

else:
return PostHocSpatialTransformerWithTimeMixing(
ch,
@@ -1173,7 +1186,7 @@ def forward(
timesteps: th.Tensor,
context: Optional[th.Tensor] = None,
y: Optional[th.Tensor] = None,
# cam: Optional[th.Tensor] = None,
cam: Optional[th.Tensor] = None,
time_context: Optional[th.Tensor] = None,
num_video_frames: Optional[int] = None,
image_only_indicator: Optional[th.Tensor] = None,
@@ -1199,7 +1212,7 @@ def forward(
h,
emb,
context=context,
# cam=cam,
cam=cam,
image_only_indicator=image_only_indicator,
cond_view=cond_view,
cond_motion=cond_motion,
@@ -1213,7 +1226,7 @@ def forward(
h,
emb,
context=context,
# cam=cam,
cam=cam,
image_only_indicator=image_only_indicator,
cond_view=cond_view,
cond_motion=cond_motion,
@@ -1228,7 +1241,7 @@ def forward(
h,
emb,
context=context,
# cam=cam,
cam=cam,
image_only_indicator=image_only_indicator,
cond_view=cond_view,
cond_motion=cond_motion,
89 changes: 59 additions & 30 deletions sgm/modules/spacetime_attention.py
@@ -1,6 +1,7 @@
from functools import partial

import torch
import torch.nn.functional as F

from ..modules.attention import *
from ..modules.diffusionmodules.util import (
@@ -359,11 +360,17 @@ def __init__(
use_linear=False,
context_dim=None,
use_spatial_context=False,
use_camera_emb=False,
use_3d_attention=False,
separate_motion_merge_factor=False,
adm_in_channels=None,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
merge_factor_motion: float = 0.5,
apply_sigmoid_to_merge: bool = True,
time_context_dim=None,
motion_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
@@ -388,6 +395,10 @@ def __init__(
self.time_depth = time_depth
self.depth = depth
self.max_time_embed_period = max_time_embed_period
self.use_camera_emb = use_camera_emb
self.motion_context_dim = motion_context_dim
self.use_3d_attention = use_3d_attention
self.separate_motion_merge_factor = separate_motion_merge_factor

time_mix_d_head = d_head
n_time_mix_heads = n_heads
@@ -398,9 +409,6 @@ def __init__(
if use_spatial_context:
time_context_dim = context_dim

camera_context_dim = time_context_dim
motion_context_dim = 4 # time_context_dim

# Camera attention layer
self.time_mix_blocks = nn.ModuleList(
[
@@ -409,7 +417,7 @@ def __init__(
n_time_mix_heads,
time_mix_d_head,
dropout=dropout,
context_dim=camera_context_dim,
context_dim=time_context_dim,
timesteps=timesteps,
checkpoint=checkpoint,
ff_in=ff_in,
@@ -449,9 +457,10 @@ def __init__(
self.in_channels = in_channels

time_embed_dim = self.in_channels * 4
time_embed_channels = adm_in_channels if self.use_camera_emb else self.in_channels
# Camera view embedding
self.time_mix_time_embed = nn.Sequential(
linear(self.in_channels, time_embed_dim),
linear(time_embed_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, self.in_channels),
)
@@ -486,12 +495,16 @@ def __init__(
self.time_mixer = AlphaBlender(
alpha=merge_factor, merge_strategy=merge_strategy
)
if self.separate_motion_merge_factor:
self.time_mixer_motion = AlphaBlender(
alpha=merge_factor_motion, merge_strategy=merge_strategy
)

def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
# cam: Optional[torch.Tensor] = None,
cam: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
@@ -500,16 +513,19 @@ def forward(
time_step: Optional[int] = None,
name: Optional[str] = None,
) -> torch.Tensor:
_, _, h, w = x.shape
# context: b t 1024
# cond_view: b*v 4 h w
# cond_motion: b*t 4 h w
# image_only_indicator: b t*v
b, t, d1 = context.shape # CLIP
v, d2 = cond_view.shape[0]//b, cond_view.shape[1] # VAE
_, c, h, w = x.shape

x_in = x
spatial_context = None
if exists(context):
spatial_context = context

# cond_view: b v 4 h w
# cond_motion: b t 4 h w
b, t, d1 = context.shape # CLIP
v, d2 = cond_view.shape[0]//b, cond_view.shape[1] # VAE
cond_view = torch.nn.functional.interpolate(cond_view, size=(h,w), mode="bilinear") # b*v d h w
spatial_context = context[:,:,None].repeat(1,1,v,1).reshape(-1,1,d1) # (b*t*v) 1 d1
camera_context = context[:,:,None].repeat(1,1,h*w,1).reshape(-1,1,d1) # (b*t*h*w) 1 d1
@@ -518,10 +534,9 @@ def forward(
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c") # 21 x 4096 x 320
x = rearrange(x, "b c h w -> b (h w) c")
if self.use_linear:
x = self.proj_in(x)
c = x.shape[-1]

if self.time_mix_legacy:
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
@@ -536,19 +551,26 @@ def forward(
max_period=self.max_time_embed_period,
)
emb_time = self.time_mix_motion_embed(t_emb)
emb_time = emb_time[:, None, :] # b*t x 1 x 320
emb_time = emb_time[:, None, :] # b*t 1 c

num_views = torch.arange(v, device=x.device)
num_views = repeat(num_views, "t -> b t", b=b)
num_views = rearrange(num_views, "b t -> (b t)")
v_emb = timestep_embedding(
num_views,
self.in_channels,
repeat_only=False,
max_period=self.max_time_embed_period,
)
emb_view = self.time_mix_time_embed(v_emb)
emb_view = emb_view[:, None, :] # b*v x 1 x 320
if self.use_camera_emb:
emb_view = self.time_mix_time_embed(cam.view(b,t,v,-1)[:,0].reshape(b*v,-1))
emb_view = emb_view[:, None, :]
else:
num_views = torch.arange(v, device=x.device)
num_views = repeat(num_views, "t -> b t", b=b)
num_views = rearrange(num_views, "b t -> (b t)")
v_emb = timestep_embedding(
num_views,
self.in_channels,
repeat_only=False,
max_period=self.max_time_embed_period,
)
emb_view = self.time_mix_time_embed(v_emb)
emb_view = emb_view[:, None, :] # b*v 1 c

if self.use_3d_attention:
emb_view = emb_view.repeat(1, h*w, 1).view(-1,1,c) # b*v*h*w 1 c

for it_, (block, time_block, mot_block) in enumerate(
zip(self.transformer_blocks, self.time_mix_blocks, self.motion_blocks)
@@ -560,7 +582,10 @@ def forward(
)

# Camera attention
x = x.view(b, t, v, h*w, c).permute(0,2,1,3,4).reshape(b*v,-1,c) # b*v t*h*w c
if self.use_3d_attention:
x = x.view(b, t, v, h*w, c).permute(0,2,3,1,4).reshape(-1,t,c) # b*v*h*w t c
else:
x = x.view(b, t, v, h*w, c).permute(0,2,1,3,4).reshape(b*v,-1,c) # b*v t*h*w c
x_mix = x + emb_view
x_mix = time_block(x_mix, context=camera_context, timesteps=v)
if self.time_mix_legacy:
@@ -569,20 +594,24 @@ def forward(
x = self.time_mixer(
x_spatial=x,
x_temporal=x_mix,
image_only_indicator=image_only_indicator[:,:v],
image_only_indicator=torch.zeros_like(image_only_indicator[:,:1].repeat(1,x.shape[0]//b)),
)

# Motion attention
x = x.view(b, v, t, h*w, c).permute(0,2,1,3,4).reshape(b*t,-1,c) # b*t v*h*w c
if self.use_3d_attention:
x = x.view(b, v, h*w, t, c).permute(0,3,1,2,4).reshape(b*t,-1,c) # b*t v*h*w c
else:
x = x.view(b, v, t, h*w, c).permute(0,2,1,3,4).reshape(b*t,-1,c) # b*t v*h*w c
x_mix = x + emb_time
x_mix = mot_block(x_mix, context=motion_context, timesteps=t)
if self.time_mix_legacy:
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
else:
x = self.time_mixer(
motion_mixer = self.time_mixer_motion if self.separate_motion_merge_factor else self.time_mixer
x = motion_mixer(
x_spatial=x,
x_temporal=x_mix,
image_only_indicator=image_only_indicator[:,:t],
image_only_indicator=torch.zeros_like(image_only_indicator[:,:1].repeat(1,x.shape[0]//b)),
)

x = x.view(b, t, v, h*w, c).reshape(-1,h*w,c) # b*t*v h*w c
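
To help read the merging logic above: both the camera-attention and motion-attention outputs are blended back into the spatial features with an `AlphaBlender`, and `separate_motion_merge_factor=True` gives the motion branch its own blender (`time_mixer_motion`). The sketch below is a simplified stand-in for that blending, assuming the learned merge strategy reduces to a single sigmoid-gated scalar; it is not the repository's `AlphaBlender` implementation.

```
# Hedged sketch of the sigmoid-gated blending applied after the camera and motion attention
# blocks. The real AlphaBlender also supports fixed and per-frame ("learned_with_images")
# strategies driven by image_only_indicator.
import torch
import torch.nn as nn

class AlphaBlenderSketch(nn.Module):
    def __init__(self, alpha: float = 0.5):
        super().__init__()
        self.mix_factor = nn.Parameter(torch.tensor([alpha]))  # learned blend weight

    def forward(self, x_spatial: torch.Tensor, x_temporal: torch.Tensor) -> torch.Tensor:
        a = torch.sigmoid(self.mix_factor)       # keep the weight in (0, 1)
        return a * x_spatial + (1.0 - a) * x_temporal

# With separate_motion_merge_factor=True, the camera and motion branches each get their
# own instance (time_mixer and time_mixer_motion in the diff above).
camera_mixer, motion_mixer = AlphaBlenderSketch(), AlphaBlenderSketch()
```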