Skip to content

Commit d0cf681

Browse files
authored
[Tests] add: tests for t2i adapter training. (huggingface#4947)
add: tests for t2i adapter training.
1 parent dfec61f commit d0cf681

File tree

2 files changed

+38
-8
lines changed

2 files changed

+38
-8
lines changed

examples/t2i_adapter/train_t2i_adapter_sdxl.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,13 @@ def parse_args(input_args=None):
245245
default=None,
246246
help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
247247
)
248+
parser.add_argument(
249+
"--adapter_model_name_or_path",
250+
type=str,
251+
default=None,
252+
help="Path to pretrained adapter model or model identifier from huggingface.co/models."
253+
" If not specified adapter weights are initialized w.r.t the configurations of SDXL.",
254+
)
248255
parser.add_argument(
249256
"--revision",
250257
type=str,
@@ -840,14 +847,18 @@ def main(args):
840847
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
841848
)
842849

843-
logger.info("Initializing t2iadapter weights from unet")
844-
t2iadapter = T2IAdapter(
845-
in_channels=3,
846-
channels=(320, 640, 1280, 1280),
847-
num_res_blocks=2,
848-
downscale_factor=16,
849-
adapter_type="full_adapter_xl",
850-
)
850+
if args.adapter_model_name_or_path:
851+
logger.info("Loading existing adapter weights.")
852+
t2iadapter = T2IAdapter.from_pretrained(args.adapter_model_name_or_path)
853+
else:
854+
logger.info("Initializing t2iadapter weights.")
855+
t2iadapter = T2IAdapter(
856+
in_channels=3,
857+
channels=(320, 640, 1280, 1280),
858+
num_res_blocks=2,
859+
downscale_factor=16,
860+
adapter_type="full_adapter_xl",
861+
)
851862

852863
# `accelerate` 0.16.0 will have better support for customized saving
853864
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):

examples/test_examples.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,6 +1528,25 @@ def test_controlnet_sdxl(self):
15281528

15291529
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
15301530

1531+
def test_t2i_adapter_sdxl(self):
1532+
with tempfile.TemporaryDirectory() as tmpdir:
1533+
test_args = f"""
1534+
examples/t2i_adapter/train_t2i_adapter_sdxl.py
1535+
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
1536+
--adapter_model_name_or_path=hf-internal-testing/tiny-adapter
1537+
--dataset_name=hf-internal-testing/fill10
1538+
--output_dir={tmpdir}
1539+
--resolution=64
1540+
--train_batch_size=1
1541+
--gradient_accumulation_steps=1
1542+
--max_train_steps=9
1543+
--checkpointing_steps=2
1544+
""".split()
1545+
1546+
run_command(self._launch_args + test_args)
1547+
1548+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
1549+
15311550
def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
15321551
with tempfile.TemporaryDirectory() as tmpdir:
15331552
test_args = f"""

0 commit comments

Comments
 (0)