119 changes: 100 additions & 19 deletions examples/text_to_image/inference_tpu_single_device.py
@@ -12,7 +12,7 @@ def parser(args):
parser.add_argument(
'--batch-size',
type=int,
default=2,  # previous default: 8
help='Number of images to generate'
)

@@ -26,7 +26,7 @@ def parser(args):
parser.add_argument(
'--inf-steps',
type=int,
default=2,  # previous default: 30
help='Number of denoising (inference) steps per generated image.'
)

@@ -35,33 +35,114 @@

def main(args):
server = xp.start_server(9012)  # profiler server; keep the handle alive for the run
device = xm.xla_device()
# The pipeline is created further below, after the run parameters are printed.

bs = args.batch_size
inference_steps = args.inf_steps
height = width = args.width

prompts = ["a photo of an astronaut riding a horse on mars"] * bs
print(f'batch size = {bs}, inference steps = {inference_steps}',
f'height = width = {width}',
flush=True
)

pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-0.9",
use_safetensors=True,
)
pipe.to(device)
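# Note (assumption about the runtime, not part of the original script): torch_xla
# traces lazily, so the first pipe() call on the XLA device typically also pays the
# XLA compilation cost; later calls with the same shapes should reuse the cached
# computation. Worth keeping in mind when reading the timings printed below.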

# Baseline: call the pipeline from a plain Python loop.
print('starting inference', flush=True)
iters = 3
start2 = time()
for i in range(iters):
start = time()
image = pipe(prompts,
num_inference_steps=inference_steps,
height=height,
width=width,
).images[0]
print(f'Step {i} inference time {time()-start} sec', flush=True)
print(f'Plain Python loop: {iters} pipeline calls took {time()-start2} sec', flush=True)


# These imports would normally sit at module top; kept local while the
# while_loop path is experimental.
import torch
import torch_xla.experimental.fori_loop  # provides the XLA while_loop support
from torch_xla.experimental.fori_loop import _xla_while_loop
from torch._higher_order_ops.while_loop import while_loop  # public entry point, kept for reference
def cond_fn(init, limit_value):
return limit_value[0] <= init[0]
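# Contract assumed here for the experimental while_loop higher-order op:
# cond_fn(*carried) returns a boolean scalar tensor, and body_fn(*carried)
# returns a tuple with the same structure, shapes, and dtypes as the carried
# inputs. The carried state is (counter, limit): the loop keeps running while
# limit <= counter, and body_fn decrements the counter by one each iteration.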

def body_fn(init, limit_value):
# One pipeline call per iteration. The generated image is discarded; only the
# (counter, limit) loop state is carried between iterations.
image = pipe(prompts,
num_inference_steps=inference_steps,
height=height,
width=width,
).images[0]
one_value = torch.ones(1, dtype=torch.int32, device=device)
two_value = limit_value.clone()
return (torch.sub(init, one_value), two_value)

# The same pipeline calls, driven through the XLA while_loop lowering.
# while_loop(cond_fn, body_fn, (init, limit_value)) is the public entry point;
# the private _xla_while_loop is called directly here while debugging.
start = time()
init = torch.tensor([iters], dtype=torch.int32, device=device)
limit_value = torch.tensor([0], dtype=torch.int32, device=device)
res = _xla_while_loop(cond_fn, body_fn, (init, limit_value))
print(f'_xla_while_loop: {iters} pipeline calls took {time()-start} sec', flush=True)
print('result of while_loop:', res)
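# Optional sanity check: compare against a plain eager loop. _fake_while_loop
# below is a hypothetical local sketch (not the torch_xla test utility), left
# commented out because every body_fn call runs the full pipeline:
#
# def _fake_while_loop(cond_fn, body_fn, operands):
#     while cond_fn(*operands):
#         operands = body_fn(*operands)
#     return operands
#
# expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
# print('eager reference result:', expected)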



if __name__ == '__main__':