Skip to content

Commit 3f57538

Browse files
committed
kind of added lip loss
1 parent 157cf07 commit 3f57538

File tree

15 files changed

+1980
-218
lines changed

15 files changed

+1980
-218
lines changed
Lines changed: 374 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,374 @@
1+
model:
2+
base_learning_rate: 3.e-5
3+
target: sgm.models.diffusion.DiffusionEngine
4+
params:
5+
input_key: latents
6+
no_log_keys: [audio_emb, fps_id, motion_bucket_id, cond_aug]
7+
scale_factor: 0.18215
8+
disable_first_stage_autocast: True
9+
ckpt_path: logs/2024-05-28T11-10-27_example_training-svd_interpolation_no_emb/checkpoints/last.ckpt/checkpoint/mp_rank_00_model_states.pt
10+
remove_keys_from_weights: []
11+
compile_model: False
12+
en_and_decode_n_samples_a_time: 1
13+
# optimizer_config:
14+
# target: deepspeed.ops.adam.DeepSpeedCPUAdam
15+
16+
scheduler_config:
17+
target: sgm.lr_scheduler.LambdaLinearScheduler
18+
params:
19+
warm_up_steps: [1000]
20+
cycle_lengths: [10000000000000]
21+
f_start: [1.e-6]
22+
f_max: [1.]
23+
f_min: [1.]
24+
25+
to_freeze: []
26+
to_unfreeze: []
27+
28+
# LoRA
29+
use_lora: False
30+
lora_config:
31+
search_class_str: Linear
32+
target_replace_module: null
33+
r_linear: 16
34+
r_conv: 16
35+
loras: null # path to lora .pt
36+
# verbose: False
37+
# dropout_p: 0.0
38+
# scale: 1.0
39+
# search_class: both
40+
41+
denoiser_config:
42+
target: sgm.modules.diffusionmodules.denoiser.Denoiser
43+
params:
44+
scaling_config:
45+
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
46+
47+
# network_wrapper: sgm.modules.diffusionmodules.wrappers.IdentityWrapper
48+
network_wrapper:
49+
target: sgm.modules.diffusionmodules.wrappers.InterpolationWrapper
50+
params:
51+
im_size: [512, 512] # USER: adapt this to your dataset
52+
n_channels: 4
53+
starting_mask_method: zeros
54+
add_mask: True
55+
56+
network_config:
57+
target: sgm.modules.diffusionmodules.video_model.VideoUNet
58+
params:
59+
adm_in_channels: 0
60+
num_classes: sequential
61+
use_checkpoint: True
62+
in_channels: 9
63+
out_channels: 4
64+
model_channels: 320
65+
attention_resolutions: [4, 2, 1]
66+
num_res_blocks: 2
67+
channel_mult: [1, 2, 4, 4]
68+
num_head_channels: 64
69+
use_linear_in_transformer: True
70+
transformer_depth: 1
71+
context_dim: 1024
72+
spatial_transformer_attn_type: softmax-xformers
73+
extra_ff_mix_layer: True
74+
use_spatial_context: True
75+
merge_strategy: learned_with_images
76+
video_kernel_size: [3, 1, 1]
77+
fine_tuning_method: null
78+
audio_cond_method: to_time_emb
79+
additional_audio_frames: 0
80+
audio_dim: 768
81+
unfreeze_blocks: ["input"] # Because we changed the input block
82+
# adapter_kwargs:
83+
# # down_ratio: 1
84+
# # adapter_type: null
85+
# # adapter_weight: null
86+
# # act_layer: gelu
87+
# # zero_init_last: True
88+
# # use_bias: True
89+
# # adapt_on_time: True
90+
# # condition_on: space
91+
# # condition_dim: 1280
92+
# target_replace_module: ["SpatialVideoTransformer"]
93+
# r: 16
94+
# loras: null # path to lora .pt
95+
# verbose: False
96+
# dropout_p: 0.0
97+
# scale: 1.0
98+
99+
conditioner_config:
100+
target: sgm.modules.GeneralConditioner
101+
params:
102+
emb_models:
103+
- is_trainable: False
104+
input_key: cond_frames_without_noise
105+
ucg_rate: 0.1
106+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
107+
params:
108+
n_cond_frames: 2
109+
n_copies: 1
110+
open_clip_embedding_config:
111+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
112+
params:
113+
freeze: True
114+
115+
# - input_key: fps_id
116+
# is_trainable: False
117+
# target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
118+
# params:
119+
# outdim: 256
120+
121+
# - input_key: motion_bucket_id
122+
# is_trainable: False
123+
# target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
124+
# params:
125+
# outdim: 256
126+
127+
- input_key: cond_frames
128+
is_trainable: False
129+
ucg_rate: 0.1
130+
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
131+
params:
132+
disable_encoder_autocast: True
133+
n_cond_frames: 2
134+
n_copies: 1
135+
is_ae: True
136+
load_encoder: False
137+
encoder_config:
138+
target: sgm.models.autoencoder.AutoencoderKLModeOnly
139+
params:
140+
embed_dim: 4
141+
monitor: val/rec_loss
142+
ddconfig:
143+
attn_type: vanilla-xformers
144+
double_z: True
145+
z_channels: 4
146+
resolution: 256
147+
in_channels: 3
148+
out_ch: 3
149+
ch: 128
150+
ch_mult: [1, 2, 4, 4]
151+
num_res_blocks: 2
152+
attn_resolutions: []
153+
dropout: 0.0
154+
lossconfig:
155+
target: torch.nn.Identity
156+
157+
# - input_key: cond_aug
158+
# is_trainable: False
159+
# target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
160+
# params:
161+
# outdim: 256
162+
163+
- input_key: audio_emb
164+
is_trainable: True
165+
ucg_rate: 0.2
166+
target: sgm.modules.encoders.modules.WhisperAudioEmbedder
167+
params:
168+
merge_method: mean
169+
linear_dim: null
170+
171+
first_stage_config:
172+
target: sgm.models.autoencoder.AutoencodingEngine
173+
params:
174+
loss_config:
175+
target: torch.nn.Identity
176+
regularizer_config:
177+
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
178+
encoder_config:
179+
target: sgm.modules.diffusionmodules.model.Encoder
180+
params:
181+
attn_type: vanilla
182+
double_z: True
183+
z_channels: 4
184+
resolution: 256
185+
in_channels: 3
186+
out_ch: 3
187+
ch: 128
188+
ch_mult: [1, 2, 4, 4]
189+
num_res_blocks: 2
190+
attn_resolutions: []
191+
dropout: 0.0
192+
decoder_config:
193+
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
194+
params:
195+
attn_type: vanilla
196+
double_z: True
197+
z_channels: 4
198+
resolution: 256
199+
in_channels: 3
200+
out_ch: 3
201+
ch: 128
202+
ch_mult: [1, 2, 4, 4]
203+
num_res_blocks: 2
204+
attn_resolutions: []
205+
dropout: 0.0
206+
video_kernel_size: [3, 1, 1]
207+
208+
sampler_config:
209+
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
210+
params:
211+
num_steps: 10
212+
discretization_config:
213+
target: sgm.modules.diffusionmodules.discretizer.AYSDiscretization
214+
# params:
215+
# # sigma_max: 700.0
216+
217+
guider_config:
218+
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
219+
params:
220+
max_scale: 2.5
221+
min_scale: 1.0
222+
num_frames: 14
223+
224+
loss_fn_config:
225+
target: sgm.modules.diffusionmodules.loss.StandardWithLipLoss
226+
params:
227+
lambda_lower: 1.
228+
weight_path: /data/home/antoni/code/generative-models/checkpoints/vsr_trlrs3_base.max400.pth
229+
batch2model_keys:
230+
- image_only_indicator
231+
- num_video_frames
232+
loss_weighting_config:
233+
# target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
234+
target: sgm.modules.diffusionmodules.loss_weighting.VWeighting
235+
sigma_sampler_config:
236+
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
237+
params:
238+
# p_mean: 0.7
239+
# p_std: 1.6
240+
p_mean: 1.
241+
p_std: 1.2
242+
243+
data:
244+
target: sgm.data.video_datamodule_latent.VideoDataModule
245+
params:
246+
train:
247+
datapipeline:
248+
# urls:
249+
# # USER: adapt this path the root of your custom dataset
250+
# - /data2/Datasets/LRW/webdata/train/out-{000000..000004}.tar
251+
# pipeline_config:
252+
# shardshuffle: 10000
253+
# sample_shuffle: 100 # USER: you might wanna adapt depending on your available RAM
254+
255+
# decoders:
256+
# - custom
257+
# postprocessors:
258+
# - target: sdata.mappers.SelectTuple
259+
# params:
260+
# key: 'mp4' # USER: you might wanna adapt this for your custom dataset
261+
# index: 0
262+
# - target: sdata.mappers.ToSVDFormat
263+
# params:
264+
# key: mp4
265+
# audio_key: pt
266+
# n_frames: 14
267+
# resize_size: 320
268+
# motion_id: 60
269+
# fps: 24 # FPS - 1 See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
270+
# cond_noise: [-3.0, 0.5]
271+
# mode: interpolation
272+
# filelist: /vol/paramonos2/projects/antoni/datasets/HDTF/filelist_videos_train.txt
273+
filelist: /fsx/rs2517/data/lists/HDTF/filelist_videos_train.txt
274+
resize_size: 512
275+
audio_folder: /fsx/rs2517/data/HDTF/audio
276+
video_folder: /fsx/rs2517/data/HDTF/cropped_videos_original
277+
lip_emb_folder: /fsx/antoni/data/HDTF/lipemb
278+
landmarks_folder: null
279+
video_extension: .mp4
280+
audio_extension: .wav
281+
latent_folder: null
282+
audio_in_video: False
283+
audio_rate: 16000
284+
num_frames: 14
285+
need_cond: True
286+
mode: interpolation
287+
use_latent: True
288+
latent_type: video
289+
latent_scale: 1 # For backwards compatibility
290+
from_audio_embedding: True
291+
load_all_possible_indexes: True
292+
audio_emb_type: wav2vec2
293+
# cond_noise: [-3.0, 0.5]
294+
cond_noise: 0.
295+
motion_id: 125
296+
data_mean: null
297+
data_std: null
298+
use_latent_condition: True
299+
get_lip_emb: True
300+
get_landmarks: True
301+
302+
loader:
303+
batch_size: 1
304+
num_workers: 6
305+
drop_last: True
306+
pin_memory: True
307+
persistent_workers: True
308+
# collation_fn:
309+
# target: sgm.data.collates.collate_video
310+
# params:
311+
# merge_keys: [frames]
312+
313+
# validation:
314+
315+
# datapipeline:
316+
# urls:
317+
# # USER: adapt this path the root of your custom dataset
318+
# - /data/122-2/Datasets/CREMA/webdataset/val/out-{000000..000001}.tar
319+
# pipeline_config:
320+
# shardshuffle: 10000
321+
# sample_shuffle: 1000 # USER: you might wanna adapt depending on your available RAM
322+
323+
# decoders:
324+
# - video
325+
# postprocessors:
326+
# - target: sdata.mappers.SelectTuple
327+
# params:
328+
# key: 'mp4' # USER: you might wanna adapt this for your custom dataset
329+
# index: 0
330+
# - target: sdata.mappers.ToSVDFormat
331+
# params:
332+
# key: mp4
333+
# n_frames: 14
334+
# resize_size: 256
335+
# cond_noise: [-3.0, 0.5]
336+
337+
# loader:
338+
# batch_size: 2
339+
# num_workers: 6
340+
341+
lightning:
342+
modelcheckpoint:
343+
params:
344+
every_n_train_steps: 5000
345+
save_top_k: 1
346+
347+
callbacks:
348+
metrics_over_trainsteps_checkpoint:
349+
params:
350+
every_n_train_steps: 25000
351+
352+
video_logger:
353+
target: sgm.callbacks.video_logger.VideoLogger
354+
params:
355+
disabled: False
356+
enable_autocast: False
357+
batch_frequency: 1000
358+
max_videos: 1
359+
increase_log_steps: False
360+
log_first_step: True
361+
log_videos_kwargs:
362+
ucg_keys: [cond_frames, cond_frames_without_noise, audio_emb]
363+
use_ema_scope: False
364+
N: 1
365+
n_rows: 1
366+
367+
trainer:
368+
devices: -1
369+
benchmark: False
370+
num_sanity_val_steps: 1
371+
accumulate_grad_batches: 1
372+
max_epochs: 1000
373+
precision: bf16-mixed
374+
num_nodes: 1

0 commit comments

Comments
 (0)