66
77
88class ScoreSdeVePipeline (DiffusionPipeline ):
9- def __init__ (self , model , scheduler ):
9+ def __init__ (self , unet , scheduler ):
1010 super ().__init__ ()
11- self .register_modules (model = model , scheduler = scheduler )
11+ self .register_modules (unet = unet , scheduler = scheduler )
1212
1313 @torch .no_grad ()
1414 def __call__ (self , batch_size = 1 , num_inference_steps = 2000 , generator = None , torch_device = None , output_type = "pil" ):
15+
1516 if torch_device is None :
1617 torch_device = "cuda" if torch .cuda .is_available () else "cpu"
1718
18- img_size = self .model .config .sample_size
19+ img_size = self .unet .config .sample_size
1920 shape = (batch_size , 3 , img_size , img_size )
2021
21- model = self .model .to (torch_device )
22+ model = self .unet .to (torch_device )
2223
2324 sample = torch .randn (* shape ) * self .scheduler .config .sigma_max
2425 sample = sample .to (torch_device )
@@ -31,7 +32,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch
3132
3233 # correction step
3334 for _ in range (self .scheduler .correct_steps ):
34- model_output = self .model (sample , sigma_t )["sample" ]
35+ model_output = self .unet (sample , sigma_t )["sample" ]
3536 sample = self .scheduler .step_correct (model_output , sample )["prev_sample" ]
3637
3738 # prediction step
@@ -40,7 +41,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch
4041
4142 sample , sample_mean = output ["prev_sample" ], output ["prev_sample_mean" ]
4243
43- sample = sample .clamp (0 , 1 )
44+ sample = sample_mean .clamp (0 , 1 )
4445 sample = sample .cpu ().permute (0 , 2 , 3 , 1 ).numpy ()
4546 if output_type == "pil" :
4647 sample = self .numpy_to_pil (sample )
0 commit comments