Commit 4d6e4aa

Add image generator to generate images for use with geneval (#172)
1 parent ab5a2f0 commit 4d6e4aa

File tree

3 files changed: +281 -22 lines changed
diffusion/evaluation/generate_geneval_images.py (new file)

Lines changed: 180 additions & 0 deletions

# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Image generation for running evaluation with geneval."""

import json
import os
from typing import Dict, Optional, Union
from urllib.parse import urlparse

import torch
from composer.core import get_precision_context
from composer.utils import dist
from composer.utils.file_helpers import get_file
from composer.utils.object_store import OCIObjectStore
from diffusers import AutoPipelineForText2Image
from torchvision.transforms.functional import to_pil_image
from tqdm.auto import tqdm


class GenevalImageGenerator:
    """Image generator that generates images from the geneval prompt set and saves them.

    Args:
        model (torch.nn.Module | str): The model to evaluate, or a HuggingFace model name if ``hf_model`` is ``True``.
        geneval_prompts (str): Path to the prompts to use for geneval (e.g. ``geneval/prompts/evaluation_metadata.jsonl``).
        load_path (str, optional): The path to load the model from. Default: ``None``.
        local_checkpoint_path (str): The local path to save the model checkpoint to. Default: ``'/tmp/model.pt'``.
        load_strict_model_weights (bool): Whether or not to strictly load model weights. Default: ``True``.
        guidance_scale (float): The guidance scale to use for evaluation. Default: ``7.0``.
        height (int): The height of the generated images. Default: ``1024``.
        width (int): The width of the generated images. Default: ``1024``.
        images_per_prompt (int): The number of images to generate per prompt. Default: ``4``.
        seed (int): The seed to use for generation. Default: ``17``.
        output_bucket (str, optional): The remote bucket to save images to. Default: ``None``.
        output_prefix (str, optional): The prefix to save images under. Default: ``None``.
        local_prefix (str): The local prefix to save images under. Default: ``'/tmp'``.
        additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the generate call.
        hf_model (bool, optional): Whether the model is a HuggingFace model. Default: ``False``.
    """

    def __init__(self,
                 model: Union[torch.nn.Module, str],
                 geneval_prompts: str,
                 load_path: Optional[str] = None,
                 local_checkpoint_path: str = '/tmp/model.pt',
                 load_strict_model_weights: bool = True,
                 guidance_scale: float = 7.0,
                 height: int = 1024,
                 width: int = 1024,
                 images_per_prompt: int = 4,
                 seed: int = 17,
                 output_bucket: Optional[str] = None,
                 output_prefix: Optional[str] = None,
                 local_prefix: str = '/tmp',
                 additional_generate_kwargs: Optional[Dict] = None,
                 hf_model: Optional[bool] = False):

        if isinstance(model, str) and not hf_model:
            raise ValueError('Can only use strings for model with hf models!')
        self.hf_model = hf_model
        if hf_model or isinstance(model, str):
            # Download on local rank 0 first so the other ranks can load from the local cache.
            if dist.get_local_rank() == 0:
                self.model = AutoPipelineForText2Image.from_pretrained(
                    model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
            dist.barrier()
            self.model = AutoPipelineForText2Image.from_pretrained(
                model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
            dist.barrier()
        else:
            self.model = model
        # Load the geneval prompts, one json record per line.
        self.geneval_prompts = geneval_prompts
        with open(geneval_prompts) as f:
            self.prompt_metadata = [json.loads(line) for line in f]
        self.load_path = load_path
        self.local_checkpoint_path = local_checkpoint_path
        self.load_strict_model_weights = load_strict_model_weights
        self.guidance_scale = guidance_scale
        self.height = height
        self.width = width
        self.images_per_prompt = images_per_prompt
        self.seed = seed
        self.generator = torch.Generator(device='cuda').manual_seed(self.seed)

        self.output_bucket = output_bucket
        self.output_prefix = output_prefix if output_prefix is not None else ''
        self.local_prefix = local_prefix
        self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}

        # Object store for uploading images
        if self.output_bucket is not None:
            parsed_remote_bucket = urlparse(self.output_bucket)
            if parsed_remote_bucket.scheme != 'oci':
                raise ValueError(f'Currently only OCI object stores are supported. Got {parsed_remote_bucket.scheme}.')
            self.object_store = OCIObjectStore(self.output_bucket.replace('oci://', ''), self.output_prefix)

        # Download the model checkpoint if needed
        if self.load_path is not None and not isinstance(self.model, str):
            if dist.get_local_rank() == 0:
                get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True)
            with dist.local_rank_zero_download_and_wait(self.local_checkpoint_path):
                # Load the model weights, dropping any metric state saved in the checkpoint.
                state_dict = torch.load(self.local_checkpoint_path, map_location='cpu')
                for key in list(state_dict['state']['model'].keys()):
                    if 'val_metrics.' in key:
                        del state_dict['state']['model'][key]
                self.model.load_state_dict(state_dict['state']['model'], strict=self.load_strict_model_weights)
            self.model = self.model.cuda().eval()

    def generate(self):
        """Core image generation function.

        Generates images from the geneval prompts at the configured guidance scale and saves them
        locally and, optionally, to object storage.
        """
        os.makedirs(os.path.join(self.local_prefix, self.output_prefix), exist_ok=True)
        # Partition the dataset across the ranks. Note this partitions prompts, not repeats.
        # e.g. 553 prompts over 8 ranks gives divmod(553, 8) = (69, 1), so rank 0 takes 70
        # prompts and ranks 1-7 take 69 each.
        dataset_len = len(self.prompt_metadata)
        samples_per_rank, remainder = divmod(dataset_len, dist.get_world_size())
        start_idx = dist.get_global_rank() * samples_per_rank + min(remainder, dist.get_global_rank())
        end_idx = start_idx + samples_per_rank
        if dist.get_global_rank() < remainder:
            end_idx += 1
        print(f'Rank {dist.get_global_rank()} processing samples {start_idx} to {end_idx} of {dataset_len} total.')
        # Iterate over the dataset
        for sample_id in tqdm(range(start_idx, end_idx)):
            metadata = self.prompt_metadata[sample_id]
            # Write the metadata jsonl
            output_dir = os.path.join(self.local_prefix, f'{sample_id:0>5}')
            os.makedirs(output_dir, exist_ok=True)
            with open(os.path.join(output_dir, 'metadata.jsonl'), 'w') as f:
                json.dump(metadata, f)
            caption = metadata['prompt']
            # Create dir for samples to live in
            sample_dir = os.path.join(output_dir, 'samples')
            os.makedirs(sample_dir, exist_ok=True)
            # Generate images from the captions. Take care to use a different seed for each image.
            for i in range(self.images_per_prompt):
                seed = self.seed + i
                if self.hf_model:
                    # The HF pipeline consumes the shared generator, whose state advances each call.
                    generated_image = self.model(prompt=caption,
                                                 height=self.height,
                                                 width=self.width,
                                                 guidance_scale=self.guidance_scale,
                                                 generator=self.generator,
                                                 **self.additional_generate_kwargs).images[0]
                    img = generated_image
                else:
                    with get_precision_context('amp_fp16'):
                        generated_image = self.model.generate(prompt=caption,
                                                              height=self.height,
                                                              width=self.width,
                                                              guidance_scale=self.guidance_scale,
                                                              seed=seed,
                                                              progress_bar=False,
                                                              **self.additional_generate_kwargs)  # type: ignore
                    img = to_pil_image(generated_image[0])
                # Save the images and metadata locally
                image_name = f'{i:05}.png'
                data_name = f'{i:05}.json'
                img_local_path = os.path.join(sample_dir, image_name)
                data_local_path = os.path.join(sample_dir, data_name)
                img.save(img_local_path)
                metadata = {
                    'image_name': image_name,
                    'prompt': caption,
                    'guidance_scale': self.guidance_scale,
                    'seed': seed
                }
                with open(data_local_path, 'w') as f:
                    json.dump(metadata, f)
                # Upload the image and metadata to cloud storage
                output_sample_prefix = os.path.join(self.output_prefix, f'{sample_id:0>5}', 'samples')
                if self.output_bucket is not None:
                    self.object_store.upload_object(object_name=os.path.join(output_sample_prefix, image_name),
                                                    filename=img_local_path)
                    # Upload the metadata
                    self.object_store.upload_object(object_name=os.path.join(output_sample_prefix, data_name),
                                                    filename=data_local_path)
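
For orientation, a minimal usage sketch of the new class (not part of the commit). It assumes the geneval repo has been cloned so the prompts file exists, and it exercises the HuggingFace path; the model id and paths mirror the example yaml further below.

```python
# Hypothetical usage sketch, not part of this commit. Assumes the geneval repo
# is cloned locally so the prompts jsonl exists, and a CUDA device is available.
from diffusion.evaluation.generate_geneval_images import GenevalImageGenerator

generator = GenevalImageGenerator(
    model='black-forest-labs/FLUX.1-schnell',  # an HF model id, so hf_model=True
    geneval_prompts='geneval/prompts/evaluation_metadata.jsonl',
    guidance_scale=7.0,
    images_per_prompt=4,
    local_prefix='/tmp/geneval-images',
    hf_model=True)
generator.generate()  # each rank generates and saves its shard of the prompts
```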

diffusion/generate.py

Lines changed: 24 additions & 22 deletions

@@ -4,7 +4,7 @@
 """Generate images from a model."""

 import operator
-from typing import List
+from typing import Any, List, Optional

 import hydra
 from composer import Algorithm, ComposerModel
@@ -16,7 +16,20 @@
 from omegaconf import DictConfig
 from torch.utils.data import Dataset

-from diffusion.evaluation.generate_images import ImageGenerator
+
+def _make_dataset(config: DictConfig, tokenizer: Optional[Any] = None) -> Dataset:
+    """Build the evaluation dataset from the config, passing the tokenizer through if one is given."""
+    if config.hf_dataset:
+        # Download on local rank 0 first so the other ranks can read from the local cache.
+        if dist.get_local_rank() == 0:
+            dataset = load_dataset(config.dataset.name, split=config.dataset.split)
+        dist.barrier()
+        dataset = load_dataset(config.dataset.name, split=config.dataset.split)
+        dist.barrier()
+    elif tokenizer:
+        dataset = hydra.utils.instantiate(config.dataset, tokenizer=tokenizer)
+    else:
+        dataset: Dataset = hydra.utils.instantiate(config.dataset)
+    return dataset

@@ -37,20 +50,6 @@ def generate(config: DictConfig) -> None:

     tokenizer = model.tokenizer if hasattr(model, 'tokenizer') else None

-    # The dataset to use for evaluation
-    if config.hf_dataset:
-        if dist.get_local_rank() == 0:
-            dataset = load_dataset(config.dataset.name, split=config.dataset.split)
-        dist.barrier()
-        dataset = load_dataset(config.dataset.name, split=config.dataset.split)
-        dist.barrier()
-    elif tokenizer:
-        dataset = hydra.utils.instantiate(config.dataset, tokenizer=tokenizer)
-    else:
-        dataset: Dataset = hydra.utils.instantiate(config.dataset)
-
     # Build list of algorithms.
     algorithms: List[Algorithm] = []

@@ -78,12 +77,15 @@ def generate(config: DictConfig) -> None:
                 precision=Precision(ag_conf['precision']),
                 optimizers=None,
             )
-
-    image_generator: ImageGenerator = hydra.utils.instantiate(config.generator,
-                                                              model=model,
-                                                              dataset=dataset,
-                                                              hf_model=config.hf_model,
-                                                              hf_dataset=config.hf_dataset)
+    # Only build a dataset if the config specifies one; the geneval generator brings its own prompts.
+    if 'dataset' in config:
+        dataset = _make_dataset(config, tokenizer)
+        image_generator = hydra.utils.instantiate(config.generator,
+                                                  model=model,
+                                                  dataset=dataset,
+                                                  hf_model=config.hf_model,
+                                                  hf_dataset=config.hf_dataset)
+    else:
+        image_generator = hydra.utils.instantiate(config.generator, model=model, hf_model=config.hf_model)

     def generate_from_model():
         image_generator.generate()
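
To make the new control flow concrete, here is a hedged sketch of the two config shapes `generate()` now accepts. The field names follow the diff above; the dataset name and split are illustrative placeholders.

```python
# Illustrative only: the two config shapes generate() now handles.
from omegaconf import OmegaConf

# A config with a dataset: _make_dataset() builds it and hands it to the generator.
with_dataset = OmegaConf.create({
    'hf_model': False,
    'hf_dataset': True,
    'dataset': {'name': 'org/some-prompt-dataset', 'split': 'train'},  # placeholder
    'generator': {'_target_': 'diffusion.evaluation.generate_images.ImageGenerator'},
})

# A config without a dataset (the new geneval path): the generator owns its prompts,
# so no dataset is instantiated or passed.
without_dataset = OmegaConf.create({
    'hf_model': True,
    'generator': {
        '_target_': 'diffusion.evaluation.generate_geneval_images.GenevalImageGenerator',
        'geneval_prompts': 'geneval/prompts/evaluation_metadata.jsonl',
    },
})

assert 'dataset' in with_dataset and 'dataset' not in without_dataset
```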
Lines changed: 77 additions & 0 deletions

# Example yaml for running geneval on FLUX.1-schnell model
name: geneval-flux-1-schnell
compute:
  cluster: # your cluster name
  instance: # your instance name
  gpus: # number of gpus
env_variables:
  HYDRA_FULL_ERROR: '1'
image: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04
scheduling:
  resumable: false
  priority: medium
  max_retries: 0
integrations:
- integration_type: git_repo
  git_repo: mosaicml/diffusion
  git_branch: main
  pip_install: .[all] --no-deps # We install with no deps to use only the specific deps needed for geneval
- integration_type: pip_packages
  packages:
  - huggingface-hub[hf_transfer]>=0.23.2
  - numpy==1.26.4
  - pandas
  - open_clip_torch
  - clip-benchmark
  - openmim
  - sentencepiece
  - mosaicml
  - mosaicml-streaming
  - hydra-core
  - hydra-colorlog
  - diffusers[torch]==0.30.3
  - transformers[torch]==4.44.2
  - torchmetrics[image]
  - lpips
  - clean-fid
  - gradio
  - datasets
  - peft
command: 'cd diffusion

  pip install clip@git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33

  mim install mmengine mmcv-full==1.7.2

  apt-get update && apt-get install libgl1-mesa-glx -y

  git clone https://github.com/djghosh13/geneval.git

  git clone https://github.com/open-mmlab/mmdetection.git

  cd mmdetection; git checkout 2.x; pip install -v -e .; cd ..

  composer run_generation.py --config-path /mnt/config --config-name parameters

  cd geneval

  ./evaluation/download_models.sh eval_models

  python evaluation/evaluate_images.py /tmp/geneval-images --outfile outputs.jsonl --model-path eval_models

  python evaluation/summary_scores.py outputs.jsonl
  '
parameters:
  seed: 18
  dist_timeout: 300
  hf_model: true # We will use a model from huggingface
  model:
    name: black-forest-labs/FLUX.1-schnell # Model name from huggingface
  generator:
    _target_: diffusion.evaluation.generate_geneval_images.GenevalImageGenerator
    geneval_prompts: geneval/prompts/evaluation_metadata.jsonl # Path to geneval prompts jsonl
    height: 1024 # Generated image height
    width: 1024 # Generated image width
    local_prefix: /tmp/geneval-images # Local path to save images to; geneval reads images from here
    output_bucket: # Your output oci bucket name (optional)
    output_prefix: # Your output prefix (optional)
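
As a sanity check, a small script sketch (assuming the run has finished and `local_prefix` is `/tmp/geneval-images`) that walks the directory layout the generator writes, which is the structure geneval's `evaluation/evaluate_images.py` then reads:

```python
# Sketch: inspect the layout GenevalImageGenerator writes under local_prefix.
# Each prompt gets a zero-padded folder holding metadata.jsonl and a samples/ dir.
import json
import os

root = '/tmp/geneval-images'  # matches local_prefix in the yaml above
for sample_id in sorted(os.listdir(root)):  # '00000', '00001', ...
    sample_dir = os.path.join(root, sample_id)
    with open(os.path.join(sample_dir, 'metadata.jsonl')) as f:
        prompt_metadata = json.load(f)  # the geneval record for this prompt
    images = [n for n in os.listdir(os.path.join(sample_dir, 'samples')) if n.endswith('.png')]
    print(f"{sample_id}: {prompt_metadata['prompt']!r} -> {len(images)} images")
```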
