|
| 1 | +# Copyright 2022 MosaicML Diffusion authors |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +"""Image generation for runnning evaluation with geneval.""" |
| 5 | + |
| 6 | +import json |
| 7 | +import os |
| 8 | +from typing import Dict, Optional, Union |
| 9 | +from urllib.parse import urlparse |
| 10 | + |
| 11 | +import torch |
| 12 | +from composer.core import get_precision_context |
| 13 | +from composer.utils import dist |
| 14 | +from composer.utils.file_helpers import get_file |
| 15 | +from composer.utils.object_store import OCIObjectStore |
| 16 | +from diffusers import AutoPipelineForText2Image |
| 17 | +from torchvision.transforms.functional import to_pil_image |
| 18 | +from tqdm.auto import tqdm |
| 19 | + |
| 20 | + |
| 21 | +class GenevalImageGenerator: |
| 22 | + """Image generator that generates images from the geneval prompt set and saves them. |
| 23 | +
|
| 24 | + Args: |
| 25 | + model (torch.nn.Module): The model to evaluate. |
| 26 | + geneval_prompts (str): Path to the prompts to use for geneval (ex: `geneval/prompts/evaluation_metadata.json`). |
| 27 | + load_path (str, optional): The path to load the model from. Default: ``None``. |
| 28 | + local_checkpoint_path (str, optional): The local path to save the model checkpoint. Default: ``'/tmp/model.pt'``. |
| 29 | + load_strict_model_weights (bool): Whether or not to strict load model weights. Default: ``True``. |
| 30 | + guidance_scale (float): The guidance scale to use for evaluation. Default: ``7.0``. |
| 31 | + height (int): The height of the generated images. Default: ``1024``. |
| 32 | + width (int): The width of the generated images. Default: ``1024``. |
| 33 | + images_per_prompt (int): The number of images to generate per prompt. Default: ``4``. |
| 34 | + load_strict_model_weights (bool): Whether or not to strict load model weights. Default: ``True``. |
| 35 | + seed (int): The seed to use for generation. Default: ``17``. |
| 36 | + output_bucket (str, Optional): The remote to save images to. Default: ``None``. |
| 37 | + output_prefix (str, Optional): The prefix to save images to. Default: ``None``. |
| 38 | + local_prefix (str): The local prefix to save images to. Default: ``/tmp``. |
| 39 | + additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method. |
| 40 | + hf_model: (bool, Optional): whether the model is HF or not. Default: ``False``. |
| 41 | + """ |
| 42 | + |
| 43 | + def __init__(self, |
| 44 | + model: Union[torch.nn.Module, str], |
| 45 | + geneval_prompts: str, |
| 46 | + load_path: Optional[str] = None, |
| 47 | + local_checkpoint_path: str = '/tmp/model.pt', |
| 48 | + load_strict_model_weights: bool = True, |
| 49 | + guidance_scale: float = 7.0, |
| 50 | + height: int = 1024, |
| 51 | + width: int = 1024, |
| 52 | + images_per_prompt: int = 4, |
| 53 | + seed: int = 17, |
| 54 | + output_bucket: Optional[str] = None, |
| 55 | + output_prefix: Optional[str] = None, |
| 56 | + local_prefix: str = '/tmp', |
| 57 | + additional_generate_kwargs: Optional[Dict] = None, |
| 58 | + hf_model: Optional[bool] = False): |
| 59 | + |
| 60 | + if isinstance(model, str) and hf_model == False: |
| 61 | + raise ValueError('Can only use strings for model with hf models!') |
| 62 | + self.hf_model = hf_model |
| 63 | + if hf_model or isinstance(model, str): |
| 64 | + if dist.get_local_rank() == 0: |
| 65 | + self.model = AutoPipelineForText2Image.from_pretrained( |
| 66 | + model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}') |
| 67 | + dist.barrier() |
| 68 | + self.model = AutoPipelineForText2Image.from_pretrained( |
| 69 | + model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}') |
| 70 | + dist.barrier() |
| 71 | + else: |
| 72 | + self.model = model |
| 73 | + # Load the geneval prompts |
| 74 | + self.geneval_prompts = geneval_prompts |
| 75 | + with open(geneval_prompts) as f: |
| 76 | + self.prompt_metadata = [json.loads(line) for line in f] |
| 77 | + self.load_path = load_path |
| 78 | + self.local_checkpoint_path = local_checkpoint_path |
| 79 | + self.load_strict_model_weights = load_strict_model_weights |
| 80 | + self.guidance_scale = guidance_scale |
| 81 | + self.height = height |
| 82 | + self.width = width |
| 83 | + self.images_per_prompt = images_per_prompt |
| 84 | + self.seed = seed |
| 85 | + self.generator = torch.Generator(device='cuda').manual_seed(self.seed) |
| 86 | + |
| 87 | + self.output_bucket = output_bucket |
| 88 | + self.output_prefix = output_prefix if output_prefix is not None else '' |
| 89 | + self.local_prefix = local_prefix |
| 90 | + self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {} |
| 91 | + |
| 92 | + # Object store for uploading images |
| 93 | + if self.output_bucket is not None: |
| 94 | + parsed_remote_bucket = urlparse(self.output_bucket) |
| 95 | + if parsed_remote_bucket.scheme != 'oci': |
| 96 | + raise ValueError(f'Currently only OCI object stores are supported. Got {parsed_remote_bucket.scheme}.') |
| 97 | + self.object_store = OCIObjectStore(self.output_bucket.replace('oci://', ''), self.output_prefix) |
| 98 | + |
| 99 | + # Download the model checkpoint if needed |
| 100 | + if self.load_path is not None and not isinstance(self.model, str): |
| 101 | + if dist.get_local_rank() == 0: |
| 102 | + get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True) |
| 103 | + with dist.local_rank_zero_download_and_wait(self.local_checkpoint_path): |
| 104 | + # Load the model |
| 105 | + state_dict = torch.load(self.local_checkpoint_path, map_location='cpu') |
| 106 | + for key in list(state_dict['state']['model'].keys()): |
| 107 | + if 'val_metrics.' in key: |
| 108 | + del state_dict['state']['model'][key] |
| 109 | + self.model.load_state_dict(state_dict['state']['model'], strict=self.load_strict_model_weights) |
| 110 | + self.model = self.model.cuda().eval() |
| 111 | + |
| 112 | + def generate(self): |
| 113 | + """Core image generation function. Generates images at a given guidance scale. |
| 114 | +
|
| 115 | + Args: |
| 116 | + guidance_scale (float): The guidance scale to use for image generation. |
| 117 | + """ |
| 118 | + os.makedirs(os.path.join(self.local_prefix, self.output_prefix), exist_ok=True) |
| 119 | + # Partition the dataset across the ranks. Note this partitions prompts, not repeats. |
| 120 | + dataset_len = len(self.prompt_metadata) |
| 121 | + samples_per_rank, remainder = divmod(dataset_len, dist.get_world_size()) |
| 122 | + start_idx = dist.get_global_rank() * samples_per_rank + min(remainder, dist.get_global_rank()) |
| 123 | + end_idx = start_idx + samples_per_rank |
| 124 | + if dist.get_global_rank() < remainder: |
| 125 | + end_idx += 1 |
| 126 | + print(f'Rank {dist.get_global_rank()} processing samples {start_idx} to {end_idx} of {dataset_len} total.') |
| 127 | + # Iterate over the dataset |
| 128 | + for sample_id in tqdm(range(start_idx, end_idx)): |
| 129 | + metadata = self.prompt_metadata[sample_id] |
| 130 | + # Write the metadata jsonl |
| 131 | + output_dir = os.path.join(self.local_prefix, f'{sample_id:0>5}') |
| 132 | + os.makedirs(output_dir, exist_ok=True) |
| 133 | + with open(os.path.join(output_dir, 'metadata.jsonl'), 'w') as f: |
| 134 | + json.dump(metadata, f) |
| 135 | + caption = metadata['prompt'] |
| 136 | + # Create dir for samples to live in |
| 137 | + sample_dir = os.path.join(output_dir, 'samples') |
| 138 | + os.makedirs(sample_dir, exist_ok=True) |
| 139 | + # Generate images from the captions. Take care to use a different seed for each image |
| 140 | + for i in range(self.images_per_prompt): |
| 141 | + seed = self.seed + i |
| 142 | + if self.hf_model: |
| 143 | + generated_image = self.model(prompt=caption, |
| 144 | + height=self.height, |
| 145 | + width=self.width, |
| 146 | + guidance_scale=self.guidance_scale, |
| 147 | + generator=self.generator, |
| 148 | + **self.additional_generate_kwargs).images[0] |
| 149 | + img = generated_image |
| 150 | + else: |
| 151 | + with get_precision_context('amp_fp16'): |
| 152 | + generated_image = self.model.generate(prompt=caption, |
| 153 | + height=self.height, |
| 154 | + width=self.width, |
| 155 | + guidance_scale=self.guidance_scale, |
| 156 | + seed=seed, |
| 157 | + progress_bar=False, |
| 158 | + **self.additional_generate_kwargs) # type: ignore |
| 159 | + img = to_pil_image(generated_image[0]) |
| 160 | + # Save the images and metadata locally |
| 161 | + image_name = f'{i:05}.png' |
| 162 | + data_name = f'{i:05}.json' |
| 163 | + img_local_path = os.path.join(sample_dir, image_name) |
| 164 | + data_local_path = os.path.join(sample_dir, data_name) |
| 165 | + img.save(img_local_path) |
| 166 | + metadata = { |
| 167 | + 'image_name': image_name, |
| 168 | + 'prompt': caption, |
| 169 | + 'guidance_scale': self.guidance_scale, |
| 170 | + 'seed': seed |
| 171 | + } |
| 172 | + json.dump(metadata, open(f'{data_local_path}', 'w')) |
| 173 | + # Upload the image and metadata to cloud storage |
| 174 | + output_sample_prefix = os.path.join(self.output_prefix, f'{sample_id:0>5}', 'samples') |
| 175 | + if self.output_bucket is not None: |
| 176 | + self.object_store.upload_object(object_name=os.path.join(output_sample_prefix, image_name), |
| 177 | + filename=img_local_path) |
| 178 | + # Upload the metadata |
| 179 | + self.object_store.upload_object(object_name=os.path.join(output_sample_prefix, data_name), |
| 180 | + filename=data_local_path) |
0 commit comments