
Commit 97db8ed

Author: Vikram Voleti
Gradio updates
1 parent bdbae99 · commit 97db8ed

File tree

4 files changed: +656 −5 lines changed


scripts/demo/sv3d_helpers.py

Lines changed: 19 additions & 4 deletions
@@ -2,6 +2,7 @@

 import matplotlib.pyplot as plt
 import numpy as np
+from PIL import Image


 def generate_dynamic_cycle_xy_values(
@@ -74,8 +75,9 @@ def gen_dynamic_loop(length=21, elev_deg=0):
     return np.roll(azim_rad, -1), np.roll(elev_rad, -1)


-def plot_3D(azim, polar, save_path, dynamic=True):
-    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+def plot_3D(azim, polar, save_path=None, dynamic=True):
+    if save_path is not None:
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
     elev = np.deg2rad(90) - polar
     fig = plt.figure(figsize=(5, 5))
     ax = fig.add_subplot(projection="3d")
@@ -98,7 +100,20 @@ def plot_3D(azim, polar, save_path, dynamic=True):
         ax.scatter(xs[i + 1], ys[i + 1], zs[i + 1], s=100, color=col[i + 1])
     ax.scatter(xs[:1], ys[:1], zs[:1], s=120, facecolors="none", edgecolors="k")
     ax.scatter(xs[-1:], ys[-1:], zs[-1:], s=120, facecolors="none", edgecolors="k")
-    ax.view_init(elev=30, azim=-20, roll=0)
-    plt.savefig(save_path, bbox_inches="tight")
+    ax.view_init(elev=40, azim=-20, roll=0)
+    ax.xaxis.set_ticklabels([])
+    ax.yaxis.set_ticklabels([])
+    ax.zaxis.set_ticklabels([])
+    if save_path is None:
+        fig.canvas.draw()
+        lst = list(fig.canvas.get_width_height())
+        lst.append(3)
+        image = Image.fromarray(
+            np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(lst)
+        )
+    else:
+        plt.savefig(save_path, bbox_inches="tight")
     plt.clf()
     plt.close()
+    if save_path is None:
+        return image
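With save_path=None, plot_3D now renders the orbit preview to an in-memory PIL image, which is what the Gradio app below feeds to a gr.Image component. A minimal usage sketch with made-up camera values:

    import numpy as np
    from scripts.demo.sv3d_helpers import plot_3D

    # 21 evenly spaced azimuths at a fixed 10-degree elevation (illustrative values)
    azim = np.linspace(0, 2 * np.pi, 22)[1:]
    polar = np.full(21, np.deg2rad(90 - 10))

    preview = plot_3D(azim=azim, polar=polar, save_path=None, dynamic=False)
    preview.save("orbit_preview.png")  # preview is a PIL.Image.Image

One caveat: fig.canvas.get_width_height() returns (width, height) while the tostring_rgb() buffer is laid out row-major as (height, width, 3), so the reshape only lines up because the 5x5-inch figure is square.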

scripts/demo/sv3d_p_gradio.py

Lines changed: 341 additions & 0 deletions
@@ -0,0 +1,341 @@
+# Adding this at the very top of app.py to make 'generative-models' directory discoverable
+import os
+import sys
+
+sys.path.append(os.path.dirname(__file__))
+
+import random
+from glob import glob
+from pathlib import Path
+from typing import List, Optional
+
+import cv2
+import gradio as gr
+import imageio
+import numpy as np
+import torch
+from einops import rearrange, repeat
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from rembg import remove
+from scripts.demo.sv3d_helpers import gen_dynamic_loop, plot_3D
+from scripts.sampling.simple_video_sample import (
+    get_batch,
+    get_unique_embedder_keys_from_conditioner,
+    load_model,
+)
+from sgm.inference.helpers import embed_watermark
+from torchvision.transforms import ToTensor
+
+version = "sv3d_p"  # replace with 'sv3d_p' or 'sv3d_u' for other models
+
+# Define the repo, local directory and filename
+repo_id = "stabilityai/sv3d"
+filename = f"{version}.safetensors"  # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors"
+local_dir = "checkpoints"
+local_ckpt_path = os.path.join(local_dir, filename)
+
+# Check if the file already exists
+if not os.path.exists(local_ckpt_path):
+    # If the file doesn't exist, download it
+    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
+    print("File downloaded.")
+else:
+    print("File already exists. No need to download.")
+
+device = "cuda"
+max_64_bit_int = 2**63 - 1
+
+num_frames = 21
+num_steps = 50
+model_config = f"scripts/sampling/configs/{version}.yaml"
+
+model, filter = load_model(
+    model_config,
+    device,
+    num_frames,
+    num_steps,
+)
+
+
+def gen_orbit(orbit, elev_deg):
+    global polars_rad
+    global azimuths_rad
+    if orbit == "dynamic":
+        azim_rad, elev_rad = gen_dynamic_loop(length=num_frames, elev_deg=elev_deg)
+        polars_rad = np.deg2rad(90) - elev_rad
+        azimuths_rad = azim_rad
+    else:
+        polars_rad = np.array([np.deg2rad(90 - elev_deg)] * num_frames)
+        azimuths_rad = np.linspace(0, 2 * np.pi, num_frames + 1)[1:]
+
+    plot = plot_3D(
+        azim=azimuths_rad,
+        polar=polars_rad,
+        save_path=None,
+        dynamic=(orbit == "dynamic"),
+    )
+    return plot
+
+
+def sample(
+    input_path: str = "assets/test_image.png",  # Can either be image file or folder with image files
+    seed: Optional[int] = None,
+    randomize_seed: bool = True,
+    orbit: str = "same elevation",
+    elev_deg: float = 10.0,
+    decoding_t: int = 7,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
+    device: str = "cuda",
+    output_folder: str = None,
+    image_frame_ratio: Optional[float] = None,
+):
+    """
+    Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
+    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
+    """
+    if randomize_seed:
+        seed = random.randint(0, max_64_bit_int)
+
+    torch.manual_seed(seed)
+
+    path = Path(input_path)
+    all_img_paths = []
+    if path.is_file():
+        if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
+            all_img_paths = [input_path]
+        else:
+            raise ValueError("Path is not valid image file.")
+    elif path.is_dir():
+        all_img_paths = sorted(
+            [
+                f
+                for f in path.iterdir()
+                if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
+            ]
+        )
+        if len(all_img_paths) == 0:
+            raise ValueError("Folder does not contain any images.")
+    else:
+        raise ValueError
+
+    for input_img_path in all_img_paths:
+
+        image = Image.open(input_img_path)
+        if image.mode == "RGBA":
+            pass
+        else:
+            # remove bg
+            image.thumbnail([768, 768], Image.Resampling.LANCZOS)
+            image = remove(image.convert("RGBA"), alpha_matting=True)
+
+        # resize object in frame
+        image_arr = np.array(image)
+        in_w, in_h = image_arr.shape[:2]
+        ret, mask = cv2.threshold(
+            np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY
+        )
+        x, y, w, h = cv2.boundingRect(mask)
+        max_size = max(w, h)
+        side_len = (
+            int(max_size / image_frame_ratio) if image_frame_ratio is not None else in_w
+        )
+        padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
+        center = side_len // 2
+        padded_image[
+            center - h // 2 : center - h // 2 + h,
+            center - w // 2 : center - w // 2 + w,
+        ] = image_arr[y : y + h, x : x + w]
+        # resize frame to 576x576
+        rgba = Image.fromarray(padded_image).resize((576, 576), Image.LANCZOS)
+        # white bg
+        rgba_arr = np.array(rgba) / 255.0
+        rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
+        input_image = Image.fromarray((rgb * 255).astype(np.uint8))
+
+        image = ToTensor()(input_image)
+        image = image * 2.0 - 1.0
+
+        image = image.unsqueeze(0).to(device)
+        H, W = image.shape[2:]
+        assert image.shape[1] == 3
+        F = 8
+        C = 4
+        shape = (num_frames, C, H // F, W // F)
+        if (H, W) != (576, 576) and "sv3d" in version:
+            print(
+                "WARNING: The conditioning frame you provided is not 576x576. This leads to suboptimal performance as model was only trained on 576x576."
+            )
+
+        cond_aug = 1e-5
+
+        value_dict = {}
+        value_dict["cond_aug"] = cond_aug
+        value_dict["cond_frames_without_noise"] = image
+        value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
+        value_dict["cond_aug"] = cond_aug
+
+        value_dict["polars_rad"] = polars_rad
+        value_dict["azimuths_rad"] = azimuths_rad
+
+        output_folder = output_folder or f"outputs/gradio/{version}"
+        cond_aug = 1e-5
+
+        with torch.no_grad():
+            with torch.autocast(device):
+                batch, batch_uc = get_batch(
+                    get_unique_embedder_keys_from_conditioner(model.conditioner),
+                    value_dict,
+                    [1, num_frames],
+                    T=num_frames,
+                    device=device,
+                )
+                c, uc = model.conditioner.get_unconditional_conditioning(
+                    batch,
+                    batch_uc=batch_uc,
+                    force_uc_zero_embeddings=[
+                        "cond_frames",
+                        "cond_frames_without_noise",
+                    ],
+                )
+
+                for k in ["crossattn", "concat"]:
+                    uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
+                    uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
+                    c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
+                    c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
+
+                randn = torch.randn(shape, device=device)
+
+                additional_model_inputs = {}
+                additional_model_inputs["image_only_indicator"] = torch.zeros(
+                    2, num_frames
+                ).to(device)
+                additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
+
+                def denoiser(input, sigma, c):
+                    return model.denoiser(
+                        model.model, input, sigma, c, **additional_model_inputs
+                    )
+
+                samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
+                model.en_and_decode_n_samples_a_time = decoding_t
+                samples_x = model.decode_first_stage(samples_z)
+                samples_x[-1:] = value_dict["cond_frames_without_noise"]
+                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+                os.makedirs(output_folder, exist_ok=True)
+                base_count = len(glob(os.path.join(output_folder, "*.mp4")))
+
+                imageio.imwrite(
+                    os.path.join(output_folder, f"{base_count:06d}.jpg"), input_image
+                )
+
+                samples = embed_watermark(samples)
+                samples = filter(samples)
+                vid = (
+                    (rearrange(samples, "t c h w -> t h w c") * 255)
+                    .cpu()
+                    .numpy()
+                    .astype(np.uint8)
+                )
+                video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
+                imageio.mimwrite(video_path, vid)
+
+    return video_path, seed
+
+
+def resize_image(image_path, output_size=(576, 576)):
+    image = Image.open(image_path)
+    # Calculate aspect ratios
+    target_aspect = output_size[0] / output_size[1]  # Aspect ratio of the desired size
+    image_aspect = image.width / image.height  # Aspect ratio of the original image
+
+    # Resize then crop if the original image is larger
+    if image_aspect > target_aspect:
+        # Resize the image to match the target height, maintaining aspect ratio
+        new_height = output_size[1]
+        new_width = int(new_height * image_aspect)
+        resized_image = image.resize((new_width, new_height), Image.LANCZOS)
+        # Calculate coordinates for cropping
+        left = (new_width - output_size[0]) / 2
+        top = 0
+        right = (new_width + output_size[0]) / 2
+        bottom = output_size[1]
+    else:
+        # Resize the image to match the target width, maintaining aspect ratio
+        new_width = output_size[0]
+        new_height = int(new_width / image_aspect)
+        resized_image = image.resize((new_width, new_height), Image.LANCZOS)
+        # Calculate coordinates for cropping
+        left = 0
+        top = (new_height - output_size[1]) / 2
+        right = output_size[0]
+        bottom = (new_height + output_size[1]) / 2
+
+    # Crop the image
+    cropped_image = resized_image.crop((left, top, right, bottom))
+
+    return cropped_image
+
+
+with gr.Blocks() as demo:
+    gr.Markdown(
+        """# Demo for SV3D_p from Stability AI ([model](https://huggingface.co/stabilityai/sv3d), [news](https://stability.ai/news/introducing-stable-video-3d))
+#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/sv3d/blob/main/LICENSE)): generate a 21-frame orbital video from a single image, at variable elevation and azimuth.
+Generation takes ~40s (for 50 steps) on an A100.
+"""
+    )
+    with gr.Row():
+        with gr.Column():
+            image = gr.Image(label="Upload your image", type="filepath")
+            generate_btn = gr.Button("Generate")
+        video = gr.Video()
+    with gr.Row():
+        with gr.Column():
+            elev_deg = gr.Slider(
+                label="Elevation (in degrees)",
+                info="Elevation of the camera in the conditioning image, in degrees.",
+                value=10.0,
+                minimum=-10,
+                maximum=30,
+            )
+            orbit = gr.Dropdown(
+                ["same elevation", "dynamic"],
+                label="Orbit",
+                info="Choose which orbit to generate",
+            )
+        plot_image = gr.Image()
+    with gr.Accordion("Advanced options", open=False):
+        seed = gr.Slider(
+            label="Seed",
+            value=23,
+            randomize=True,
+            minimum=0,
+            maximum=max_64_bit_int,
+            step=1,
+        )
+        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+        decoding_t = gr.Slider(
+            label="Decode n frames at a time",
+            info="Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.",
+            value=7,
+            minimum=1,
+            maximum=14,
+        )
+
+    image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
+
+    elev_deg.change(gen_orbit, [orbit, elev_deg], plot_image)
+    orbit.change(gen_orbit, [orbit, elev_deg], plot_image)
+    # seed.change(gen_orbit, [orbit, elev_deg], plot_image)
+
+    generate_btn.click(
+        fn=sample,
+        inputs=[image, seed, randomize_seed, decoding_t],
+        outputs=[video, seed],
+        api_name="video",
+    )
+
+if __name__ == "__main__":
+    demo.queue(max_size=20)
+    demo.launch(share=True)
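Assuming checkpoints/sv3d_p.safetensors and the sampling config resolve as above, the demo would be launched from the root of the generative-models repo (hypothetical invocation); share=True also prints a public Gradio URL:

    python scripts/demo/sv3d_p_gradio.py

One subtlety worth noting: generate_btn.click passes inputs=[image, seed, randomize_seed, decoding_t], so the slider value binds to sample's fourth positional parameter, orbit, and decoding_t keeps its default of 7. The rendered orbit is unaffected, because sample reads the polars_rad/azimuths_rad globals that gen_orbit sets rather than its own orbit/elev_deg arguments; wiring inputs=[image, seed, randomize_seed, orbit, elev_deg, decoding_t] would restore the intended mapping.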
