 from composer.core import get_precision_context
 from composer.loggers import LoggerDestination
 from composer.utils import dist
-from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
 from torchmetrics.multimodal import CLIPScore
-from torchvision.transforms.functional import to_pil_image
+from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from tqdm.auto import tqdm
-from transformers import PreTrainedTokenizerBase

 os.environ['TOKENIZERS_PARALLELISM'] = 'false'

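Unlike torchvision's to_tensor, the newly imported pil_to_tensor returns the PIL image as an unscaled uint8 CHW tensor, which is why the generation loop below divides by 255 explicitly. A standalone sketch of the difference (illustrative, not part of the commit):

    from PIL import Image
    from torchvision.transforms.functional import pil_to_tensor

    img = Image.new('RGB', (4, 4), color=(255, 0, 0))
    t = pil_to_tensor(img)  # dtype=torch.uint8, shape (3, 4, 4), values in [0, 255]
    t_float = t / 255.0     # promotes to float in [0, 1], the range the evaluator checks for
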
@@ -32,7 +31,7 @@ class CleanFIDEvaluator:

     Args:
         model (ComposerModel): The model to evaluate.
-        eval_dataloader (DataLoader): The dataloader to use for evaluation.
+        dataset (Dataset): The dataset to draw evaluation images and prompts from.
         clip_metric (CLIPScore): The CLIPScore metric to use for evaluation.
         load_path (str, optional): The path to load the model from. Default: ``None``.
         guidance_scales (List[float]): The guidance scales to use for evaluation.
@@ -52,13 +51,14 @@ class CleanFIDEvaluator:
         default_prompt (Optional[str]): An optional default prompt to add before each eval prompt. Default: ``None``.
         default_negative_prompt (Optional[str]): An optional default negative prompt to add before each
             negative prompt. Default: ``None``.
+        sdxl_conditioning (bool): Whether or not to include SDXL conditioning in the evaluation. Default: ``False``.
         additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method.

     """

     def __init__(self,
                  model: ComposerModel,
-                 eval_dataloader: DataLoader,
+                 dataset: Dataset,
                  clip_metric: CLIPScore,
                  load_path: Optional[str] = None,
                  guidance_scales: Optional[List[float]] = None,
@@ -75,10 +75,10 @@ def __init__(self,
                  prompts: Optional[List[str]] = None,
                  default_prompt: Optional[str] = None,
                  default_negative_prompt: Optional[str] = None,
+                 sdxl_conditioning: bool = False,
                  additional_generate_kwargs: Optional[Dict] = None):
         self.model = model
-        self.tokenizer: PreTrainedTokenizerBase = model.tokenizer
-        self.eval_dataloader = eval_dataloader
+        self.dataset = dataset
         self.clip_metric = clip_metric
         self.load_path = load_path
         self.guidance_scales = guidance_scales if guidance_scales is not None else [1.0]
@@ -89,20 +89,19 @@ def __init__(self,
         self.loggers = loggers
         self.seed = seed
         self.output_dir = output_dir
-        self.num_samples = num_samples if num_samples is not None else float('inf')
+        self.num_samples = num_samples
         self.precision = precision
         self.prompts = prompts if prompts is not None else ['A shiba inu wearing a blue sweater']
         self.default_prompt = default_prompt
         self.default_negative_prompt = default_negative_prompt
+        self.sdxl_conditioning = sdxl_conditioning
         self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
-        self.sdxl = model.sdxl

         # Load the model
         trainer = Trainer(model=self.model,
                           load_path=self.load_path,
                           load_weights_only=True,
                           load_strict_model_weights=load_strict_model_weights,
-                          eval_dataloader=self.eval_dataloader,
                           seed=self.seed,
                           loggers=self.loggers)
         self.trainer = trainer
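Since this Trainer is only used to restore checkpoint weights (load_weights_only=True), it no longer needs an eval_dataloader at all. A hypothetical construction sketch for the updated signature; the dataset name and CLIP checkpoint are assumptions, not from this commit:

    from torchmetrics.multimodal import CLIPScore

    evaluator = CleanFIDEvaluator(
        model=model,               # a ComposerModel exposing a .generate method
        dataset=coco_val_dataset,  # assumed map-style Dataset of {image, caption} samples
        clip_metric=CLIPScore(model_name_or_path='openai/clip-vit-base-patch16'),
        num_samples=1000,
        sdxl_conditioning=False,
    )
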
@@ -139,18 +138,27 @@ def _generate_images(self, guidance_scale: float):

         # Storage for prompts
         prompts = {}
-        # Iterate over the eval dataloader
-        num_batches = len(self.eval_dataloader)
-        starting_seed = self.seed + num_batches * dist.get_local_rank()
-        for batch_id, batch in tqdm(enumerate(self.eval_dataloader)):
-            # Break if enough samples have been generated
-            if batch_id * self.batch_size * dist.get_world_size() >= self.num_samples:
-                break
-
-            real_images = batch[self.image_key]
-            tokenized_captions = batch[self.caption_key]
-            # Get the prompts from the tokens
-            text_captions = self.tokenizer.batch_decode(tokenized_captions, skip_special_tokens=True)
+        # Partition the dataset across the ranks
+        dataset_len = self.dataset.num_samples  # type: ignore
+        # Truncate the dataset if num_samples is specified
+        if self.num_samples is not None and self.num_samples <= dataset_len:
+            dataset_len = self.num_samples
+        elif self.num_samples is not None and self.num_samples > dataset_len:
+            raise ValueError(f'num_samples {self.num_samples} is greater than the dataset length {dataset_len}.')
+        samples_per_rank, remainder = divmod(dataset_len, dist.get_world_size())
+        start_idx = dist.get_global_rank() * samples_per_rank + min(remainder, dist.get_global_rank())
+        end_idx = start_idx + samples_per_rank
+        if dist.get_global_rank() < remainder:
+            end_idx += 1
+        print(f'Rank {dist.get_global_rank()} processing samples {start_idx} to {end_idx} of {dataset_len} total.')
+        # Iterate over the dataset
+        for sample_id in tqdm(range(start_idx, end_idx)):
+            # Set a unique seed for this sample to ensure reproducible but different randomness
+            seed = self.seed + sample_id
+            # Image and caption come from the dataset. Note the caption is untokenized
+            sample = self.dataset[sample_id]
+            real_images = pil_to_tensor(sample[self.image_key]).unsqueeze(0) / 255.0
+            text_captions = sample[self.caption_key]
             # Add default prompts if specified
             augmented_captions = text_captions
             augmented_negative_prompt = None
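The divmod partition above hands the remainder out to the lowest-numbered ranks, so every index is generated exactly once and per-rank loads differ by at most one sample. A minimal sketch of the same arithmetic, runnable without any distributed setup:

    def shard(dataset_len: int, world_size: int, rank: int) -> tuple:
        # Mirrors the start_idx/end_idx computation in _generate_images
        samples_per_rank, remainder = divmod(dataset_len, world_size)
        start = rank * samples_per_rank + min(remainder, rank)
        end = start + samples_per_rank + (1 if rank < remainder else 0)
        return start, end

    # shard(10, 4, r) for r = 0..3 gives (0, 3), (3, 6), (6, 8), (8, 10):
    # indices 0-9 are covered exactly once across the four ranks.
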
@@ -159,15 +167,12 @@ def _generate_images(self, guidance_scale: float):
             if self.default_negative_prompt:
                 augmented_negative_prompt = [f'{self.default_negative_prompt}' for _ in text_captions]

-            if self.sdxl:
-                crop_params = batch['cond_crops_coords_top_left']
-                input_size_params = batch['cond_original_size']
+            if self.sdxl_conditioning:
+                crop_params = torch.tensor([0, 0]).unsqueeze(0)
+                input_size_params = torch.tensor([self.size, self.size]).unsqueeze(0)
             else:
                 crop_params = None
                 input_size_params = None
-
-            # Ensure a new seed for each batch, as randomness in model.generate is fixed.
-            seed = starting_seed + batch_id
             # Generate images from the captions
             with get_precision_context(self.precision):
                 generated_images = self.model.generate(prompt=augmented_captions,
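SDXL's UNet takes micro-conditioning inputs in addition to the prompt: the crop's top-left coordinates and the original image size. With no collated batch to read them from anymore, the commit synthesizes neutral values, a zero crop offset and a square original size. Assuming self.size is the square eval resolution (e.g. 1024), the tensors it builds look like:

    import torch

    size = 1024  # assumed value of self.size
    crop_params = torch.tensor([0, 0]).unsqueeze(0)              # shape (1, 2): no crop offset
    input_size_params = torch.tensor([size, size]).unsqueeze(0)  # shape (1, 2): original (H, W)
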
@@ -188,11 +193,11 @@ def _generate_images(self, guidance_scale: float):
                     f'Images are expected to be in the range [0, 1]. Got max {real_images.max()} and min {real_images.min()}'
                 )
             for i, img in enumerate(real_images):
-                to_pil_image(img).save(f'{real_image_path}/{batch_id}_{i}_rank_{dist.get_local_rank()}.png')
-                prompts[f'{batch_id}_{i}_rank_{dist.get_local_rank()}'] = text_captions[i]
+                to_pil_image(img).save(f'{real_image_path}/{sample_id}_rank_{dist.get_local_rank()}.png')
+                prompts[f'{sample_id}_rank_{dist.get_local_rank()}'] = text_captions[i]
             # Save the generated images
             for i, img in enumerate(generated_images):
-                to_pil_image(img).save(f'{gen_image_path}/{batch_id}_{i}_rank_{dist.get_local_rank()}.png')
+                to_pil_image(img).save(f'{gen_image_path}/{sample_id}_rank_{dist.get_local_rank()}.png')

         # Save the prompts as json
         json.dump(prompts, open(f'{real_image_path}/prompts_rank_{dist.get_local_rank()}.json', 'w'))
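Real and generated images now share a filename stem of the form {sample_id}_rank_{rank}, and each rank writes its own prompt map keyed by that stem, so images and captions can be re-paired downstream. An illustrative file, using the class's default prompt:

    prompts_rank_0.json:
    {"0_rank_0": "A shiba inu wearing a blue sweater"}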