We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent eb2ef31 commit 1bd4c9eCopy full SHA for 1bd4c9e
examples/text_to_image/train_text_to_image_flax.py
@@ -340,11 +340,10 @@ def preprocess_train(examples):
340
341
return examples
342
343
- if jax.process_index() == 0:
344
- if args.max_train_samples is not None:
345
- dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
346
# Set the training transforms
347
- train_dataset = dataset["train"].with_transform(preprocess_train)
+ train_dataset = dataset["train"].with_transform(preprocess_train)
348
349
def collate_fn(examples):
350
pixel_values = torch.stack([example["pixel_values"] for example in examples])
0 commit comments