diff --git a/README.md b/README.md
index 8c570e5..48c8287 100644
--- a/README.md
+++ b/README.md
@@ -3,85 +3,59 @@ Experiments with "FixMatch" on Cifar10 dataset.
Based on ["FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence"](https://arxiv.org/abs/2001.07685)
+and its official [code](https://github.com/google-research/fixmatch).

## Requirements

```bash
+pip install --upgrade --pre hydra-core tensorboardX
pip install --upgrade --pre pytorch-ignite
```

## Training

```bash
-python -u main_fixmatch.py
-# or python -u main_fixmatch.py --params "data_path=/path/to/cifar10"
+python -u main_fixmatch.py model=WRN-28-2
```

This script automatically trains in multiple GPUs (`torch.nn.DistributedParallel`).

-### Distributed Data Parallel (DDP) on multiple GPUs (Experimental)
+To specify the input data path and the output folder:
+```bash
+python -u main_fixmatch.py dataflow.data_path=/data/cifar10/ hydra.run.dir=/output-fixmatch model=WRN-28-2
+```

-For example, training on 2 GPUs
+To use the wandb logger, log in and run with `online_exp_tracking.wandb=true`:
```bash
-python -u -m torch.distributed.launch --nproc_per_node=2 main_fixmatch.py --params="distributed=True"
+wandb login
+python -u main_fixmatch.py model=WRN-28-2 online_exp_tracking.wandb=true
```

-### TPU(s) on Colab (Experimental)
-
-#### Installation
+To see other options:
```bash
-VERSION = "1.5"
-!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
-!python pytorch-xla-env-setup.py --version $VERSION
+python -u main_fixmatch.py --help
```

-#### Single TPU
+### Training curves visualization
+
+By default, we use TensorBoard to log training curves:
+
```bash
-python -u main_fixmatch.py --params="device='xla'"
+tensorboard --logdir=/tmp/output-fixmatch-cifar10-hydra/
```

-#### 8 TPUs on Colab
+### Distributed Data Parallel (DDP) on multiple GPUs (Experimental)
+
+For example, training on 2 GPUs:
```bash
-python -u main_fixmatch.py --params="device='xla';distributed=True"
+python -u -m torch.distributed.launch --nproc_per_node=2 main_fixmatch.py model=WRN-28-2 distributed.backend=nccl
```

-## TODO
-
-* [x] Resume training from existing checkpoint:
-    * [x] save/load CTA
-    * [x] save ema model
-
-* [ ] DDP:
-    * [x] Synchronize CTA across processes
-    * [x] Unified GPU and TPU approach
-    * [ ] Bug: DDP performances are worse than DP on the first epochs
-
-* [ ] Logging to an online platform: NeptuneML or Trains or W&B
-
-* [ ] Replace PIL augmentations with Albumentations
-
-```python
-class BlurLimitSampler:
-    def __init__(self, blur, weights):
-        self.blur = blur # [3, 5, 7]
-        self.weights = weights # [0.1, 0.5, 0.4]
-    def get_params(self):
-        return {"ksize": int(random.choice(self.blur, p=self.weights))}
-
-class Blur(ImageOnlyTransform):
-    def __init__(self, blur_limit, always_apply=False, p=0.5):
-        super(Blur, self).__init__(always_apply, p)
-        self.blur_limit = blur_limit
-
-    def apply(self, image, ksize=3, **params):
-        return F.blur(image, ksize)
-
-    def get_params(self):
-        if isinstance(self.blur_limit, BlurLimitSampler):
-            return self.blur_limit.get_params()
-        return {"ksize": int(random.choice(np.arange(self.blur_limit[0], self.blur_limit[1] + 1, 2)))}
-
-    def get_transform_init_args_names(self):
-        return ("blur_limit",)
-```
\ No newline at end of file
+### TPU(s) on Colab (Experimental) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ZoWz1-a3bpj1xMxpM2K2qQ4Y9xvtdGWO)
+
+For example, 
training on 8 TPUs in distributed mode: +```bash +python -u main_fixmatch.py model=resnet18 distributed.backend=xla-tpu distributed.nproc_per_node=8 +# or python -u main_fixmatch.py model=WRN-28-2 distributed.backend=xla-tpu distributed.nproc_per_node=8 +``` diff --git a/TODO b/TODO new file mode 100644 index 0000000..3244f9e --- /dev/null +++ b/TODO @@ -0,0 +1,38 @@ +## TODO + +* [x] Resume training from existing checkpoint: + * [x] save/load CTA + * [x] save ema model + +* [ ] DDP: + * [x] Synchronize CTA across processes + * [ ] Bug: DDP performances are worse than DP on the first epochs + +* [x] Logging to an online platform: W&B + +* [ ] Replace PIL augmentations with Albumentations + +```python +class BlurLimitSampler: + def __init__(self, blur, weights): + self.blur = blur # [3, 5, 7] + self.weights = weights # [0.1, 0.5, 0.4] + def get_params(self): + return {"ksize": int(random.choice(self.blur, p=self.weights))} + +class Blur(ImageOnlyTransform): + def __init__(self, blur_limit, always_apply=False, p=0.5): + super(Blur, self).__init__(always_apply, p) + self.blur_limit = blur_limit + + def apply(self, image, ksize=3, **params): + return F.blur(image, ksize) + + def get_params(self): + if isinstance(self.blur_limit, BlurLimitSampler): + return self.blur_limit.get_params() + return {"ksize": int(random.choice(np.arange(self.blur_limit[0], self.blur_limit[1] + 1, 2)))} + + def get_transform_init_args_names(self): + return ("blur_limit",) +``` \ No newline at end of file diff --git a/base_train.py b/base_train.py deleted file mode 100644 index 7d51c8d..0000000 --- a/base_train.py +++ /dev/null @@ -1,348 +0,0 @@ -import argparse -from pathlib import Path -import yaml -import hashlib - -import torch -import torch.nn as nn -import torch.optim as optim - -import ignite -from ignite.engine import Events, Engine, create_supervised_evaluator -from ignite.metrics import Accuracy, Precision, Recall -from ignite.handlers import Checkpoint -from ignite.utils import convert_tensor - -from ignite.contrib.engines import common -from ignite.contrib.handlers import ProgressBar -from ignite.contrib.handlers.time_profilers import BasicTimeProfiler - -import utils -import dist_utils - - -def run(trainer, config): - assert isinstance(trainer, BaseTrainer) - debug = config["debug"] - - device = config["device"] - if device == "xla": - import torch_xla.core.xla_model as xm - device = xm.xla_device() - distributed = config["distributed"] - - local_rank = config["local_rank"] - rank = dist_utils.get_rank() - - if distributed and device == "cuda": - torch.cuda.set_device(local_rank) - - torch.manual_seed(config["seed"] + rank) - - cta = utils.get_default_cta() - - supervised_train_loader_iter, unsupervised_train_loader_iter, cta_probe_loader_iter = \ - utils.get_dataflow_iters(config, cta, distributed) - - test_loader = utils.get_test_loader( - config["data_path"], - transforms=utils.test_transforms, - batch_size=config["batch_size"], - num_workers=config["num_workers"] - ) - - model, ema_model, optimizer = utils.get_models_optimizer(config, distributed) - - sup_criterion = nn.CrossEntropyLoss() - unsup_criterion = nn.CrossEntropyLoss(reduction='none') - - num_epochs = config["num_epochs"] - epoch_length = config["epoch_length"] - total_num_iters = num_epochs * epoch_length - lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_num_iters, eta_min=0.0) - - # Setup trainer - trainer.setup( - config=config, - model=model, ema_model=ema_model, optimizer=optimizer, lr_scheduler=lr_scheduler, 
- sup_criterion=sup_criterion, unsup_criterion=unsup_criterion, - cta=cta, - ) - - # Setup handler to prepare data batches - @trainer.on(Events.ITERATION_STARTED) - def prepare_batch(e): - sup_batch = next(supervised_train_loader_iter) - unsup_batch = next(unsupervised_train_loader_iter) - cta_probe_batch = next(cta_probe_loader_iter) - e.state.batch = { - "sup_batch": utils.sup_prepare_batch(sup_batch, device, non_blocking=True), - "unsup_batch": ( - convert_tensor(unsup_batch["image"], device, non_blocking=True), - convert_tensor(unsup_batch["strong_aug"], device, non_blocking=True) - ), - "cta_probe_batch": ( - *utils.sup_prepare_batch(cta_probe_batch, device, non_blocking=True), - [utils.deserialize(p) for p in cta_probe_batch['policy']] - ) - } - sup_batch = unsup_batch = cta_probe_batch = None - - # Setup handler to update EMA model - @trainer.on(Events.ITERATION_COMPLETED, config["ema_decay"]) - def update_ema_model(ema_decay): - # EMA on parametes - for ema_param, param in zip(ema_model.parameters(), model.parameters()): - ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay) - - # Setup handlers for debugging - if debug: - - @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100)) - def log_weights_norms(_): - - if rank == 0: - wn = [] - ema_wn = [] - for ema_param, param in zip(ema_model.parameters(), model.parameters()): - wn.append(torch.mean(param.data)) - ema_wn.append(torch.mean(ema_param.data)) - - print("\n\nWeights norms") - print("\n- Raw model: {}".format(utils.to_list_str(torch.tensor(wn[:10] + wn[-10:])))) - print("- EMA model: {}\n".format(utils.to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:])))) - - if rank == 0: - profiler = BasicTimeProfiler() - profiler.attach(trainer) - - @trainer.on(Events.ITERATION_COMPLETED(every=200)) - def log_profiling(_): - results = profiler.get_results() - profiler.print_results(results) - - # Setup validation engine - metrics = { - "accuracy": Accuracy(), - "precision": Precision(average=False), - "recall": Recall(average=False), - } - - evaluator = create_supervised_evaluator( - model, metrics, - prepare_batch=utils.sup_prepare_batch, device=device, non_blocking=True - ) - ema_evaluator = create_supervised_evaluator( - ema_model, metrics, - prepare_batch=utils.sup_prepare_batch, device=device, non_blocking=True - ) - - def log_results(epoch, max_epochs, metrics, ema_metrics): - msg1 = "\n".join(["\t{:16s}: {}".format(k, utils.to_list_str(v)) for k, v in metrics.items()]) - msg2 = "\n".join(["\t{:16s}: {}".format(k, utils.to_list_str(v)) for k, v in ema_metrics.items()]) - print("\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2)) - print(utils.stats(cta)) - - def run_evaluation(): - data_loader = test_loader - le = None - if dist_utils.is_tpu_distributed(): - le = len(test_loader) - data_loader = dist_utils.to_parallel_loader(test_loader) - - evaluator.run(data_loader, epoch_length=le) - ema_evaluator.run(data_loader, epoch_length=le) - if rank == 0: - log_results( - trainer.state.epoch, - trainer.state.max_epochs, - evaluator.state.metrics, - ema_evaluator.state.metrics - ) - - ev = Events.EPOCH_COMPLETED(every=config["validate_every"]) | Events.STARTED | Events.COMPLETED - trainer.add_event_handler(ev, run_evaluation) - - # setup TB logging - if rank == 0: - tb_logger = common.setup_tb_logging( - config["output_path"], - trainer, - optimizers=optimizer, - evaluators={"validation": evaluator, "ema validation": ema_evaluator}, - log_every_iters=1 - ) - - if 
config["display_iters"]: - ProgressBar(persist=False, desc="Test evaluation").attach(evaluator) - ProgressBar(persist=False, desc="Test EMA evaluation").attach(ema_evaluator) - - data = list(range(epoch_length)) - - resume_from = list(Path(config["output_path"]).rglob("training_checkpoint*.pt*")) - if len(resume_from) > 0: - # get latest - checkpoint_fp = max(resume_from, key=lambda p: p.stat().st_mtime) - assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix()) - if rank == 0: - print("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix())) - checkpoint = torch.load(checkpoint_fp.as_posix()) - Checkpoint.load_objects(to_load=trainer.to_save, checkpoint=checkpoint) - - try: - trainer.run(data, epoch_length=epoch_length, max_epochs=config["num_epochs"] if not debug else 1) - except Exception as e: - import traceback - print(traceback.format_exc()) - - if rank == 0: - tb_logger.close() - - supervised_train_loader_iter = unsupervised_train_loader_iter = cta_probe_loader_iter = None - - -def worker_task(_, trainer, config): - - if config["device"] == "cuda": - assert torch.cuda.is_available() - - if config["distributed"]: - dist_utils.initialize() - # let each node print the info - if dist_utils.get_rank() == 0: - print("\nDistributed setting:") - print("\tworld size: {}".format(dist_utils.get_world_size())) - print("\trank: {}".format(dist_utils.get_rank())) - print("\n") - - if (not config["distributed"]) or (dist_utils.get_rank() == 0): - ds_id = config["num_train_samples_per_class"] * 10 - conf_hash = hashlib.md5(repr(config).encode("utf-8")).hexdigest() - prefix = "training" if not config["debug"] else "debug-training" - prefix += "-{}".format(config["model"]) - output_path = Path(config["output_path"]) / "{}-{}-{}".format(prefix, ds_id, conf_hash) - - if not output_path.exists(): - output_path.mkdir(parents=True) - - # dump config: - with open((output_path / "config.yaml"), "w") as h: - yaml.dump(config, h) - - output_path = output_path.as_posix() - print("Output path: {}".format(output_path)) - config["output_path"] = output_path - - try: - run(trainer, config) - except KeyboardInterrupt: - print("Catched KeyboardInterrupt -> exit") - except Exception as e: - if config["distributed"]: - dist_utils.finalize() - raise e - - if config["distributed"]: - dist_utils.finalize() - - -def main(trainer, config): - parser = argparse.ArgumentParser("Semi-Supervised Learning - FixMatch with CTA: Train WRN-28-2 on CIFAR10 dataset") - parser.add_argument( - "--params", - type=str, - help="Override default configuration with parameters: " - "data_path=/path/to/dataset;batch_size=64;num_workers=12 ...", - ) - parser.add_argument("--local_rank", type=int, help="Local process rank in distributed computation") - - args = parser.parse_args() - - local_rank = 0 - if args.local_rank is not None: - local_rank = args.local_rank - - config["local_rank"] = local_rank - - # Override config: - if args.params is not None: - for param in args.params.split(";"): - key, value = param.split("=") - if "/" not in value: - value = eval(value) - config[key] = value - - if config["local_rank"] == 0: - print("SSL Training of {} on CIFAR10@{}".format(config["model"], config["num_train_samples_per_class"] * 10)) - print("- PyTorch version: {}".format(torch.__version__)) - print("- Ignite version: {}".format(ignite.__version__)) - print("- CUDA version: {}".format(torch.version.cuda)) - - print("\n") - print("Configuration:") - for key, value in config.items(): - print("\t{}: 
{}".format(key, value)) - print("\n") - - # Download dataset - if dist_utils.is_gpu_distributed(): - if dist_utils.get_rank() == 0: - utils.CIFAR10(root=config["data_path"], train=True, download=True) - dist_utils.dist.barrier() - - if config["distributed"] and config["device"] == "xla": - assert dist_utils.has_xla_support - import torch_xla.distributed.xla_multiprocessing as xmp - # Spawns eight of the map functions, one for each of the eight cores on the Cloud TPU - # Note: Colab only supports start_method='fork' - xmp.spawn(worker_task, args=(trainer, config), nprocs=8, start_method='fork') - else: - worker_task(None, trainer, config) - - -class BaseTrainer(Engine): - - output_names = [] - - def __init__(self): - super(BaseTrainer, self).__init__(self.train_step) - self.config = self.model = self.ema_model = self.optimizer = None - self.lr_scheduler = self.sup_criterion = self.unsup_criterion = None - self.cta = self.to_save = None - - def setup(self, **kwargs): - for k, v in kwargs.items(): - if k != 'self' and not k.startswith('_'): - setattr(self, k, v) - self._setup_common_handlers() - - def _setup_common_handlers(self): - # Setup other common handlers for the trainer - debug = self.config["debug"] - - self.to_save = { - "model": self.model, - "ema_model": self.ema_model, - "optimizer": self.optimizer, - "trainer": self, - "lr_scheduler": self.lr_scheduler, - "cta": self.cta - } - - if self.config["with_nv_amp_level"] is not None: - from apex import amp - self.to_save["amp"] = amp - - common.setup_common_training_handlers( - self, - to_save=None if debug else self.to_save, - save_every_iters=self.config["checkpoint_every"], - output_path=self.config["output_path"], - output_names=self.output_names, - lr_scheduler=self.lr_scheduler, - with_pbar_on_iters=self.config["display_iters"], - log_every_iters=2 - ) - - def train_step(self, engine, batch): - raise NotImplementedError("This is the base class") diff --git a/config/dataflow/cifar10.yaml b/config/dataflow/cifar10.yaml new file mode 100644 index 0000000..26c36ce --- /dev/null +++ b/config/dataflow/cifar10.yaml @@ -0,0 +1,7 @@ +# @package _group_ +name: cifar10 + +data_path: "/tmp/cifar10" + +batch_size: 64 +num_workers: 12 \ No newline at end of file diff --git a/config/fixmatch.yaml b/config/fixmatch.yaml new file mode 100644 index 0000000..bc04420 --- /dev/null +++ b/config/fixmatch.yaml @@ -0,0 +1,43 @@ +hydra: + run: + dir: /tmp/output-fixmatch-cifar10-hydra/fixmatch/${now:%Y%m%d-%H%M%S} + job_logging: + handlers: + console: + level: WARN + root: + level: WARN + +name: fixmatch + +seed: 543 +debug: false + +# model name (from torchvision) to setup model to train. For Wide-Resnet, use "WRN-28-2" +model: "resnet18" +num_classes: 10 + +ema_decay: 0.999 + +defaults: + - dataflow: cifar10 + - solver: default + - ssl: cta_pseudo + + +solver: + unsupervised_criterion: + cls: torch.nn.CrossEntropyLoss + params: + reduction: 'none' + + +distributed: + # backend to use for distributed configuration. Possible values: None, "nccl", "xla-tpu", "gloo" etc. Default, None. + backend: null + # optional argument to setup number of processes per node. It is useful, when main python process is spawning training as child processes. 
+ nproc_per_node: null + + +online_exp_tracking: + wandb: false \ No newline at end of file diff --git a/config/fully_supervised.yaml b/config/fully_supervised.yaml new file mode 100644 index 0000000..b4777d1 --- /dev/null +++ b/config/fully_supervised.yaml @@ -0,0 +1,36 @@ +hydra: + run: + dir: /tmp/output-fixmatch-cifar10-hydra/fully_supervised/${now:%Y%m%d-%H%M%S} + job_logging: + handlers: + console: + level: WARN + root: + level: WARN + +name: fully-supervised + +seed: 543 +debug: false + +# model name (from torchvision) to setup model to train. For Wide-Resnet, use "WRN-28-2" +model: "resnet18" +num_classes: 10 + +ema_decay: 0.999 + +defaults: + - dataflow: cifar10 + - solver: default + - ssl: full_sup + + +distributed: + # backend to use for distributed configuration. Possible values: None, "nccl", "xla-tpu", "gloo" etc. Default, None. + backend: null + # optional argument to setup number of processes per node. It is useful, when main python process is spawning training as child processes. + nproc_per_node: null + + +online_exp_tracking: + wandb: false \ No newline at end of file diff --git a/config/solver/default.yaml b/config/solver/default.yaml new file mode 100644 index 0000000..3b76262 --- /dev/null +++ b/config/solver/default.yaml @@ -0,0 +1,30 @@ +# @package _group_ + +num_epochs: 1024 + +epoch_length: 128 # epoch_length * num_epochs == 2 ** 20 + +checkpoint_every: 500 + +validate_every: 1 + +resume_from: null + +optimizer: + cls: torch.optim.SGD + params: + lr: 0.03 + momentum: 0.9 + weight_decay: 0.0001 + nesterov: false + + +supervised_criterion: + cls: torch.nn.CrossEntropyLoss + + +lr_scheduler: + cls: torch.optim.lr_scheduler.CosineAnnealingLR + params: + eta_min: 0.0 + T_max: null diff --git a/config/ssl/cta_pseudo.yaml b/config/ssl/cta_pseudo.yaml new file mode 100644 index 0000000..c56c30d --- /dev/null +++ b/config/ssl/cta_pseudo.yaml @@ -0,0 +1,11 @@ +# @package _group_ + +num_train_samples_per_class: 25 + +confidence_threshold: 0.95 + +lambda_u: 1.0 + +mu_ratio: 7 + +cta_update_every: 1 \ No newline at end of file diff --git a/config/ssl/full_sup.yaml b/config/ssl/full_sup.yaml new file mode 100644 index 0000000..c7764c4 --- /dev/null +++ b/config/ssl/full_sup.yaml @@ -0,0 +1,3 @@ +# @package _group_ + +num_train_samples_per_class: 25 diff --git a/configs.py b/configs.py deleted file mode 100644 index adda2b1..0000000 --- a/configs.py +++ /dev/null @@ -1,50 +0,0 @@ - -def get_ssl_config(): - return { - # SSL settings - "num_train_samples_per_class": 25, - "mu_ratio": 7, - "ema_decay": 0.999, - } - - -def get_backend_config(): - return { - "device": "cuda", # possible values "cuda" or "xla" - "distributed": False, - - # AMP - "with_nv_amp_level": None, # if "O1" or "O2" -> train with apex/amp, otherwise fp32 (None) - } - - -def get_default_config(): - batch_size = 64 - - config = { - "seed": 12, - "data_path": "/tmp/cifar10", - "output_path": "/tmp/output-fixmatch-cifar10", - "model": "WRN-28-2", - "momentum": 0.9, - "weight_decay": 0.0005, - "batch_size": batch_size, - "num_workers": 12, - "num_epochs": 1024, - "epoch_length": 2 ** 16 // batch_size, # epoch_length * num_epochs == 2 ** 20 - "learning_rate": 0.03, - "validate_every": 1, - - # Logging: - "display_iters": True, - "checkpoint_every": 200, - "debug": False, - - # online platform logging: - "online_logging": None # "Neptune" or "WandB" - } - - config.update(get_ssl_config()) - config.update(get_backend_config()) - - return config diff --git a/ctaugment/__init__.py b/ctaugment/__init__.py new file mode 
100644 index 0000000..f94ca91 --- /dev/null +++ b/ctaugment/__init__.py @@ -0,0 +1,62 @@ +import json +from collections import OrderedDict + +from ctaugment.ctaugment import * + + +class StorableCTAugment(CTAugment): + def load_state_dict(self, state): + for k in ["decay", "depth", "th", "rates"]: + assert k in state, "{} not in {}".format(k, state.keys()) + setattr(self, k, state[k]) + + def state_dict(self): + return OrderedDict( + [(k, getattr(self, k)) for k in ["decay", "depth", "th", "rates"]] + ) + + +def get_default_cta(): + return StorableCTAugment() + + +def cta_apply(pil_img, ops): + if ops is None: + return pil_img + for op, args in ops: + pil_img = OPS[op].f(pil_img, *args) + return pil_img + + +def deserialize(policy_str): + return [OP(f=x[0], bins=x[1]) for x in json.loads(policy_str)] + + +def stats(cta): + return "\n".join( + "%-16s %s" + % ( + k, + " / ".join( + " ".join("%.2f" % x for x in cta.rate_to_p(rate)) + for rate in cta.rates[k] + ), + ) + for k in sorted(OPS.keys()) + ) + + +def interleave(x, batch, inverse=False): + """ + TF code + def interleave(x, batch): + s = x.get_shape().as_list() + return tf.reshape(tf.transpose(tf.reshape(x, [-1, batch] + s[1:]), [1, 0] + list(range(2, 1+len(s)))), [-1] + s[1:]) + """ + shape = x.shape + axes = [batch, -1] if inverse else [-1, batch] + return x.reshape(*axes, *shape[1:]).transpose(0, 1).reshape(-1, *shape[1:]) + + +def deinterleave(x, batch): + return interleave(x, batch, inverse=True) diff --git a/ctaugment.py b/ctaugment/ctaugment.py similarity index 87% rename from ctaugment.py rename to ctaugment/ctaugment.py index 0fa36b1..d8fe07c 100644 --- a/ctaugment.py +++ b/ctaugment/ctaugment.py @@ -22,8 +22,8 @@ OPS = {} -OP = namedtuple('OP', ('f', 'bins')) -Sample = namedtuple('Sample', ('train', 'probe')) +OP = namedtuple("OP", ("f", "bins")) +Sample = namedtuple("Sample", ("train", "probe")) def register(*bins): @@ -41,7 +41,7 @@ def __init__(self, depth=2, th=0.85, decay=0.99): self.th = th self.rates = {} for k, op in OPS.items(): - self.rates[k] = tuple([np.ones(x, 'f') for x in op.bins]) + self.rates[k] = tuple([np.ones(x, "f") for x in op.bins]) def rate_to_p(self, rate): p = rate + (1 - self.decay) # Avoid to have all zero. 
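
The `interleave`/`deinterleave` helpers added in `ctaugment/__init__.py` above mirror the TF code quoted in their docstring. The toy example below is an illustrative sketch (not part of the patch, and it assumes the repository root is on `PYTHONPATH`); `batch` here plays the role of `le = 2 * mu_ratio + 1` used in `main_fixmatch.py`.

```python
import torch

from ctaugment import deinterleave, interleave

# Six samples stacked group-by-group, e.g. [L0, L1, W0, W1, S0, S1]
x = torch.arange(6)

# interleave views x as (2, 3), swaps the two leading axes and flattens back,
# so samples from different groups end up adjacent along the batch dimension.
y = interleave(x, batch=3)
print(y)  # tensor([0, 3, 1, 4, 2, 5])

# deinterleave applies the inverse reshuffling and restores the original order.
z = deinterleave(y, batch=3)
assert torch.equal(z, x)
```
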
@@ -78,9 +78,17 @@ def update_rates(self, policy, proximity): rate[p] = rate[p] * self.decay + proximity * (1 - self.decay) def stats(self): - return '\n'.join('%-16s %s' % (k, ' / '.join(' '.join('%.2f' % x for x in self.rate_to_p(rate)) - for rate in self.rates[k])) - for k in sorted(OPS.keys())) + return "\n".join( + "%-16s %s" + % ( + k, + " / ".join( + " ".join("%.2f" % x for x in self.rate_to_p(rate)) + for rate in self.rates[k] + ), + ) + for k in sorted(OPS.keys()) + ) def _enhance(x, op, level): @@ -128,7 +136,10 @@ def cutout(x, level): height_loc = np.random.randint(low=0, high=img_height) width_loc = np.random.randint(low=0, high=img_width) upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2)) - lower_coord = (min(img_height, height_loc + size // 2), min(img_width, width_loc + size // 2)) + lower_coord = ( + min(img_height, height_loc + size // 2), + min(img_width, width_loc + size // 2), + ) pixels = x.load() # create the pixel map for i in range(upper_coord[0], lower_coord[0]): # for every col: for j in range(upper_coord[1], lower_coord[1]): # For every row @@ -162,7 +173,14 @@ def rescale(x, scale, method): s = x.size scale *= 0.25 crop = (scale * s[0], scale * s[1], s[0] * (1 - scale), s[1] * (1 - scale)) - methods = (Image.ANTIALIAS, Image.BICUBIC, Image.BILINEAR, Image.BOX, Image.HAMMING, Image.NEAREST) + methods = ( + Image.ANTIALIAS, + Image.BICUBIC, + Image.BILINEAR, + Image.BOX, + Image.HAMMING, + Image.NEAREST, + ) method = methods[int(method * 5.99)] return x.crop(crop).resize(x.size, method) diff --git a/dataflow/__init__.py b/dataflow/__init__.py new file mode 100644 index 0000000..c6717b8 --- /dev/null +++ b/dataflow/__init__.py @@ -0,0 +1,108 @@ +from functools import partial + +from torch.utils.data import Dataset + +from ignite.utils import convert_tensor + + +class TransformedDataset(Dataset): + def __init__(self, dataset, transforms): + self.dataset = dataset + self.transforms = transforms + + def __getitem__(self, i): + dp = self.dataset[i] + return self.transforms(dp) + + def __len__(self): + return len(self.dataset) + + +def sup_prepare_batch(batch, device, non_blocking): + x = convert_tensor(batch["image"], device, non_blocking) + y = convert_tensor(batch["target"], device, non_blocking) + return x, y + + +def cycle(dataloader): + while True: + for b in dataloader: + yield b + + +def get_supervised_train_loader( + dataset_name, root, num_train_samples_per_class, download=True, **dataloader_kwargs +): + if dataset_name == "cifar10": + from dataflow.cifar10 import ( + get_supervised_trainset, + get_supervised_train_loader, + weak_transforms, + ) + + train_dataset = get_supervised_trainset( + root, + num_train_samples_per_class=num_train_samples_per_class, + download=download, + ) + + return get_supervised_train_loader(train_dataset, **dataloader_kwargs) + + else: + raise ValueError("Unhandled dataset: {}".format(dataset_name)) + + +def get_test_loader(dataset_name, root, download=True, **dataloader_kwargs): + if dataset_name == "cifar10": + from dataflow.cifar10 import get_test_loader + + return get_test_loader(root=root, download=download, **dataloader_kwargs) + + else: + raise ValueError("Unhandled dataset: {}".format(dataset_name)) + + +def get_unsupervised_train_loader( + dataset_name, root, cta, download=True, **dataloader_kwargs +): + if dataset_name == "cifar10": + from dataflow import cifar10 + + full_train_dataset = cifar10.get_supervised_trainset( + root, num_train_samples_per_class=None, download=download + ) + + 
strong_transforms = partial(cifar10.cta_image_transforms, cta=cta) + + return cifar10.get_unsupervised_train_loader( + full_train_dataset, + transforms_weak=cifar10.weak_transforms, + transforms_strong=strong_transforms, + **dataloader_kwargs + ) + + else: + raise ValueError("Unhandled dataset: {}".format(dataset_name)) + + +def get_cta_probe_loader( + dataset_name, + root, + num_train_samples_per_class, + cta, + download=True, + **dataloader_kwargs +): + if dataset_name == "cifar10": + from dataflow.cifar10 import get_supervised_trainset, get_cta_probe_loader + + train_dataset = get_supervised_trainset( + root, + num_train_samples_per_class=num_train_samples_per_class, + download=download, + ) + + return get_cta_probe_loader(train_dataset, cta=cta, **dataloader_kwargs) + + else: + raise ValueError("Unhandled dataset: {}".format(dataset_name)) diff --git a/dataflow/cifar10.py b/dataflow/cifar10.py new file mode 100644 index 0000000..da38a86 --- /dev/null +++ b/dataflow/cifar10.py @@ -0,0 +1,410 @@ +import json +from functools import partial + +import numpy as np + +from torch.utils.data import Subset + +from torchvision import transforms as T +from torchvision.datasets.cifar import CIFAR10 + +import ignite.distributed as idist + +from dataflow import TransformedDataset +from ctaugment import cta_apply + + +weak_transforms = T.Compose( + [ + T.Pad(4), + T.RandomCrop(32), + T.RandomHorizontalFlip(), + T.ToTensor(), + T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25)), + ] +) + +test_transforms = T.Compose( + [T.ToTensor(), T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25))] +) + +cutout_image_transforms = T.Compose( + [ + T.ToTensor(), + T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25)), + T.RandomErasing(scale=(0.02, 0.15)), + ] +) + + +def get_supervised_trainset(root, num_train_samples_per_class=25, download=True): + + if num_train_samples_per_class == 25: + return get_supervised_trainset_0_250(root, download=download) + + num_classes = 10 + full_train_dataset = CIFAR10(root, train=True, download=download) + + if num_train_samples_per_class is None: + return full_train_dataset + + supervised_train_indices = [] + counter = [0] * num_classes + + np.random.seed(num_train_samples_per_class) + + indices = list(range(len(full_train_dataset))) + random_indices = np.random.permutation(indices) + + for i in random_indices: + dp = full_train_dataset[i] + if len(supervised_train_indices) >= num_classes * num_train_samples_per_class: + break + if counter[dp[1]] < num_train_samples_per_class: + counter[dp[1]] += 1 + supervised_train_indices.append(i) + + return Subset(full_train_dataset, supervised_train_indices) + + +def get_supervised_trainset_0_250(root, download=True): + full_train_dataset = CIFAR10(root, train=True, download=download) + + supervised_train_indices = [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 
118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 141, + 142, + 143, + 144, + 145, + 146, + 147, + 148, + 149, + 150, + 151, + 152, + 153, + 154, + 155, + 156, + 157, + 158, + 159, + 160, + 161, + 162, + 163, + 164, + 165, + 166, + 167, + 169, + 170, + 171, + 172, + 173, + 174, + 175, + 177, + 178, + 179, + 180, + 181, + 182, + 183, + 185, + 186, + 187, + 188, + 189, + 190, + 191, + 192, + 193, + 194, + 195, + 196, + 197, + 198, + 199, + 200, + 202, + 203, + 204, + 205, + 207, + 209, + 210, + 211, + 213, + 215, + 216, + 217, + 218, + 220, + 221, + 222, + 223, + 224, + 228, + 229, + 230, + 231, + 233, + 237, + 239, + 240, + 241, + 244, + 246, + 247, + 252, + 254, + 256, + 259, + 260, + 263, + 264, + 268, + 271, + 272, + 276, + 277, + 279, + 280, + 281, + 284, + 285, + 290, + 293, + 296, + 308, + 317, + ] + return Subset(full_train_dataset, supervised_train_indices) + + +def get_supervised_train_loader( + supervised_train_dataset, transforms=weak_transforms, **dataloader_kwargs +): + dataloader_kwargs["pin_memory"] = "cuda" in idist.device().type + dataloader_kwargs["drop_last"] = True + dataloader_kwargs["shuffle"] = dataloader_kwargs.get("sampler", None) is None + + supervised_train_loader = idist.auto_dataloader( + TransformedDataset( + supervised_train_dataset, + transforms=lambda d: {"image": transforms(d[0]), "target": d[1]}, + ), + **dataloader_kwargs + ) + return supervised_train_loader + + +def get_test_loader( + root, transforms=test_transforms, download=True, **dataloader_kwargs +): + full_test_dataset = CIFAR10(root, train=False, download=download) + + dataloader_kwargs["pin_memory"] = "cuda" in idist.device().type + dataloader_kwargs["drop_last"] = False + dataloader_kwargs["shuffle"] = False + + test_loader = idist.auto_dataloader( + TransformedDataset( + full_test_dataset, + transforms=lambda dp: {"image": transforms(dp[0]), "target": dp[1]}, + ), + **dataloader_kwargs + ) + return test_loader + + +def cta_image_transforms(pil_img, cta, transform=cutout_image_transforms): + policy = cta.policy(probe=False) + pil_img = cta_apply(pil_img, policy) + return transform(pil_img) + + +def cta_probe_transforms(dp, cta, image_transforms=cutout_image_transforms): + policy = cta.policy(probe=True) + probe = cta_apply(dp[0], policy) + probe = image_transforms(probe) + return {"image": probe, "target": dp[1], "policy": json.dumps(policy)} + + +def get_cta_probe_loader(supervised_train_dataset, cta, **dataloader_kwargs): + dataloader_kwargs["pin_memory"] = "cuda" in idist.device().type + dataloader_kwargs["drop_last"] = False + dataloader_kwargs["shuffle"] = dataloader_kwargs.get("sampler", None) is None + + cta_probe_loader = idist.auto_dataloader( + TransformedDataset( + supervised_train_dataset, transforms=partial(cta_probe_transforms, cta=cta) + ), + **dataloader_kwargs + ) + + return cta_probe_loader + + +def get_unsupervised_train_loader( + raw_dataset, transforms_weak, transforms_strong, **dataloader_kwargs +): + unsupervised_train_dataset = TransformedDataset( + raw_dataset, + transforms=lambda dp: { + "image": transforms_weak(dp[0]), + "strong_aug": transforms_strong(dp[0]), + }, + ) + + dataloader_kwargs["drop_last"] = True + dataloader_kwargs["pin_memory"] = "cuda" in idist.device().type + dataloader_kwargs["shuffle"] = dataloader_kwargs.get("sampler", None) is None + + unsupervised_train_loader = idist.auto_dataloader( + unsupervised_train_dataset, **dataloader_kwargs + 
) + return unsupervised_train_loader diff --git a/dist_utils.py b/dist_utils.py deleted file mode 100644 index 1c5f19b..0000000 --- a/dist_utils.py +++ /dev/null @@ -1,105 +0,0 @@ -import numbers - -import torch -import torch.distributed as dist - -try: - import torch_xla.core.xla_model as xm - from torch_xla.distributed.parallel_loader import ParallelLoader - has_xla_support = True -except ImportError: - has_xla_support = False - - -def is_gpu_distributed(): - return dist.is_available() and dist.is_initialized() - - -def is_tpu_distributed(): - return has_xla_support and xm.xrt_world_size() > 1 - - -def get_world_size(): - if is_gpu_distributed(): - return dist.get_world_size() - elif is_tpu_distributed(): - return xm.xrt_world_size() - else: - return 1 - - -def get_rank(): - if is_gpu_distributed(): - return dist.get_rank() - elif is_tpu_distributed(): - return xm.get_ordinal() - else: - return 0 - - -def get_num_proc_per_node(): - if is_gpu_distributed(): - return torch.cuda.device_count() - elif is_tpu_distributed(): - return len(xm.get_xla_supported_devices()) - else: - return 1 - - -def device(default_value): - if is_gpu_distributed(): - return torch.cuda.current_device() - elif is_tpu_distributed(): - return xm.xla_device() - return default_value - - -def initialize(): - if has_xla_support: - xm.rendezvous('init') - else: - torch.backends.cudnn.benchmark = True - dist.init_process_group("nccl", init_method="env://") - - -def finalize(): - if not has_xla_support: - dist.destroy_process_group() - - -def _tpu_sync_all_reduce(self, tensor): - tensor_to_number = False - if isinstance(tensor, numbers.Number): - tensor = torch.tensor(tensor, device=self._device, dtype=torch.float) - tensor_to_number = True - - if isinstance(tensor, torch.Tensor): - # check if the tensor is at specified device - if tensor.device != self._device: - tensor = tensor.to(self._device) - else: - raise TypeError("Unhandled input type {}".format(type(tensor))) - - # synchronize and reduce - xm.all_reduce("sum", [tensor, ]) - - if tensor_to_number: - return tensor.item() - return tensor - - -def to_parallel_loader(data_loader): - device = xm.xla_device() - data_loader = ParallelLoader(data_loader, [device, ]) - data_loader = data_loader.per_device_loader(device) - return data_loader - - -def _temporary_ignite_metrics_patch(): - # until merged https://github.com/pytorch/ignite/issues/992 - if is_tpu_distributed(): - from ignite.metrics import Metric - Metric._sync_all_reduce = _tpu_sync_all_reduce - - -_temporary_ignite_metrics_patch() diff --git a/main_fixmatch.py b/main_fixmatch.py index fe3e4b6..1ef98b7 100644 --- a/main_fixmatch.py +++ b/main_fixmatch.py @@ -1,12 +1,16 @@ import torch -import torch.distributed as dist +import ignite.distributed as idist from ignite.engine import Events +from ignite.utils import manual_seed, setup_logger + +import hydra +from hydra.utils import instantiate +from omegaconf import DictConfig import utils -from base_train import main, BaseTrainer -from configs import get_default_config -from ctaugment import OPS +import trainers +from ctaugment import get_default_cta, OPS, interleave, deinterleave sorted_op_names = sorted(list(OPS.keys())) @@ -17,7 +21,7 @@ def pack_as_tensor(k, bins, error, size=5, pad_value=-555.0): out[0] = sorted_op_names.index(k) le = len(bins) out[1] = le - out[2:2 + le] = torch.tensor(bins).to(error) + out[2 : 2 + le] = torch.tensor(bins).to(error) out[2 + le] = error return out @@ -25,113 +29,186 @@ def pack_as_tensor(k, bins, error, size=5, 
pad_value=-555.0): def unpack_from_tensor(t): k_index = int(t[0].item()) le = int(t[1].item()) - bins = t[2:2 + le].tolist() + bins = t[2 : 2 + le].tolist() error = t[2 + le].item() return sorted_op_names[k_index], bins, error -class FixMatchTrainer(BaseTrainer): +def training(local_rank, cfg): - output_names = ["total_loss", "sup_loss", "unsup_loss", "mask"] + logger = setup_logger("FixMatch Training", distributed_rank=idist.get_rank()) + + if local_rank == 0: + logger.info(cfg.pretty()) + + rank = idist.get_rank() + manual_seed(cfg.seed + rank) + device = idist.device() + + model, ema_model, optimizer, sup_criterion, lr_scheduler = utils.initialize(cfg) + + unsup_criterion = instantiate(cfg.solver.unsupervised_criterion) + + cta = get_default_cta() - def train_step(self, engine, batch): - self.model.train() - self.optimizer.zero_grad() + ( + supervised_train_loader, + test_loader, + unsup_train_loader, + cta_probe_loader, + ) = utils.get_dataflow(cfg, cta=cta, with_unsup=True) - x, y = batch["sup_batch"] - weak_x, strong_x = batch["unsup_batch"] + def train_step(engine, batch): + model.train() + optimizer.zero_grad() + + x, y = batch["sup_batch"]["image"], batch["sup_batch"]["target"] + if x.device != device: + x = x.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + + weak_x, strong_x = ( + batch["unsup_batch"]["image"], + batch["unsup_batch"]["strong_aug"], + ) + if weak_x.device != device: + weak_x = weak_x.to(device, non_blocking=True) + strong_x = strong_x.to(device, non_blocking=True) # according to TF code: single forward pass on concat data: [x, weak_x, strong_x] - le = 2 * self.config["mu_ratio"] + 1 - x_cat = utils.interleave(torch.cat([x, weak_x, strong_x], dim=0), le) - y_pred_cat = self.model(x_cat) - y_pred_cat = utils.deinterleave(y_pred_cat, le) + le = 2 * engine.state.mu_ratio + 1 + # Why interleave: https://github.com/google-research/fixmatch/issues/20#issuecomment-613010277 + # We need to interleave due to multiple-GPU batch norm issues. Let's say we have to GPUs, and our batch is + # comprised of labeled (L) and unlabeled (U) images. Let's use a batch size of 2 for making easier visually + # in my following example. + # + # - Without interleaving, we have a batch LLUUUUUU...U (there are 14 U). When the batch is split to be passed + # to both GPUs, we'll have two batches LLUUUUUU and UUUUUUUU. Note that all labeled examples ended up in batch1 + # sent to GPU1. The problem here is that batch norm will be computed per batch and the moments will lack + # consistency between batches. + # + # - With interleaving, by contrast, the two batches will be LUUUUUUU and LUUUUUUU. As you can notice the + # batches have the same distribution of labeled and unlabeled samples and will therefore have more consistent + # moments. + # + x_cat = interleave(torch.cat([x, weak_x, strong_x], dim=0), le) + y_pred_cat = model(x_cat) + y_pred_cat = deinterleave(y_pred_cat, le) idx1 = len(x) idx2 = idx1 + len(weak_x) y_pred = y_pred_cat[:idx1, ...] y_weak_preds = y_pred_cat[idx1:idx2, ...] # logits_weak - y_strong_preds = y_pred_cat[idx2:, ...] # logits_strong + y_strong_preds = y_pred_cat[idx2:, ...] 
# logits_strong # supervised learning: - sup_loss = self.sup_criterion(y_pred, y) + sup_loss = sup_criterion(y_pred, y) # unsupervised learning: y_weak_probas = torch.softmax(y_weak_preds, dim=1).detach() y_pseudo = y_weak_probas.argmax(dim=1) max_y_weak_probas, _ = y_weak_probas.max(dim=1) - unsup_loss_mask = (max_y_weak_probas >= self.confidence_threshold).float() - unsup_loss = (self.unsup_criterion(y_strong_preds, y_pseudo) * unsup_loss_mask).mean() + unsup_loss_mask = ( + max_y_weak_probas >= engine.state.confidence_threshold + ).float() + unsup_loss = ( + unsup_criterion(y_strong_preds, y_pseudo) * unsup_loss_mask + ).mean() - total_loss = sup_loss + self.lambda_u * unsup_loss + total_loss = sup_loss + engine.state.lambda_u * unsup_loss - if self.config["with_nv_amp_level"] is not None: - from apex import amp - with amp.scale_loss(total_loss, self.optimizer) as scaled_loss: - scaled_loss.backward() - else: - total_loss.backward() + total_loss.backward() - self.optimizer.step() + optimizer.step() return { "total_loss": total_loss.item(), "sup_loss": sup_loss.item(), "unsup_loss": unsup_loss.item(), - "mask": unsup_loss_mask.mean().item() # this should not be averaged for DDP + "mask": unsup_loss_mask.mean().item(), # this should not be averaged for DDP } - def setup(self, **kwargs): - super(FixMatchTrainer, self).setup(**kwargs) - self.confidence_threshold = self.config["confidence_threshold"] - self.lambda_u = self.config["lambda_u"] - self.add_event_handler(Events.ITERATION_COMPLETED, self.update_cta_rates) - self.distributed = dist.is_available() and dist.is_initialized() + output_names = ["total_loss", "sup_loss", "unsup_loss", "mask"] - def update_cta_rates(self): - x, y, policies = self.state.batch["cta_probe_batch"] - self.ema_model.eval() + trainer = trainers.create_trainer( + train_step, + output_names=output_names, + model=model, + ema_model=ema_model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + supervised_train_loader=supervised_train_loader, + test_loader=test_loader, + cfg=cfg, + logger=logger, + cta=cta, + unsup_train_loader=unsup_train_loader, + cta_probe_loader=cta_probe_loader, + ) + + trainer.state.confidence_threshold = cfg.ssl.confidence_threshold + trainer.state.lambda_u = cfg.ssl.lambda_u + trainer.state.mu_ratio = cfg.ssl.mu_ratio + + distributed = idist.get_world_size() > 1 + + @trainer.on(Events.ITERATION_COMPLETED(every=cfg.ssl.cta_update_every)) + def update_cta_rates(): + batch = trainer.state.batch + x, y = batch["cta_probe_batch"]["image"], batch["cta_probe_batch"]["target"] + if x.device != device: + x = x.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + + policies = batch["cta_probe_batch"]["policy"] + + ema_model.eval() with torch.no_grad(): - y_pred = self.ema_model(x) + y_pred = ema_model(x) y_probas = torch.softmax(y_pred, dim=1) # (N, C) - if not self.distributed: - for y_proba, t, policy in zip(y_probas, y, policies): + if distributed: + for y_proba, t, policy in zip(y_probas, y, policies): error = y_proba error[t] -= 1 error = torch.abs(error).sum() - self.cta.update_rates(policy, 1.0 - 0.5 * error.item()) + cta.update_rates(policy, 1.0 - 0.5 * error.item()) else: error_per_op = [] for y_proba, t, policy in zip(y_probas, y, policies): error = y_proba error[t] -= 1 error = torch.abs(error).sum() - for k, bins in policy: + for k, bins in policy: error_per_op.append(pack_as_tensor(k, bins, error)) error_per_op = torch.stack(error_per_op) - # all gather - tensor_list = [ - torch.empty_like(error_per_op) - for _ in 
range(dist.get_world_size()) - ] - dist.all_gather(tensor_list, error_per_op) - tensor_list = torch.cat(tensor_list, dim=0) + # all gather + tensor_list = idist.all_gather(error_per_op) # update cta rates for t in tensor_list: - k, bins, error = unpack_from_tensor(t) - self.cta.update_rates([(k, bins), ], 1.0 - 0.5 * error) + k, bins, error = unpack_from_tensor(t) + cta.update_rates([(k, bins),], 1.0 - 0.5 * error) + + epoch_length = cfg.solver.epoch_length + num_epochs = cfg.solver.num_epochs if not cfg.debug else 2 + try: + trainer.run( + supervised_train_loader, epoch_length=epoch_length, max_epochs=num_epochs + ) + except Exception as e: + import traceback + + print(traceback.format_exc()) + +@hydra.main(config_path="config", config_name="fixmatch") +def main(cfg: DictConfig) -> None: -def get_fixmatch_config(): - config = get_default_config() - config.update({ - # FixMatch settings - "confidence_threshold": 0.95, - "lambda_u": 1.0, - }) - return config + with idist.Parallel( + backend=cfg.distributed.backend, nproc_per_node=cfg.distributed.nproc_per_node + ) as parallel: + parallel.run(training, cfg) if __name__ == "__main__": - main(FixMatchTrainer(), get_fixmatch_config()) + main() diff --git a/main_fixmatch_2steps.py b/main_fixmatch_2steps.py deleted file mode 100644 index 24f0174..0000000 --- a/main_fixmatch_2steps.py +++ /dev/null @@ -1,73 +0,0 @@ - -import torch - -from ignite.engine import Events -from ignite.utils import convert_tensor - -import utils -from base_main import main, get_default_config -from main_fixmatch import FixMatchTrainer - - -def get_config(): - config = get_default_config() - config["num_sup_substeps"] = 2 - config["num_unsup_substeps"] = 1 - return config - - -class FixMatchTwoStepsTrainer(FixMatchTrainer): - - def train_step(self, *args, **kwargs): - self.model.train() - self.optimizer.zero_grad() - - # supervised part - total_loss = 0 - for _ in range(config["num_sup_substeps"]): - sup_batch = next(self.supervised_train_loader_iter) - x, y = utils.sup_prepare_batch(sup_batch, self.device, non_blocking=True) - y_pred = self.model(x) - sup_loss = self.sup_criterion(y_pred, y) - - if self.config["with_nv_amp_level"] is not None: - from apex import amp - with amp.scale_loss(sup_loss, self.optimizer) as scaled_loss: - scaled_loss.backward() - else: - sup_loss.backward() - total_loss += sup_loss - - # pseudo-labeling - for _ in range(config["num_unsup_substeps"]): - unsup_batch = next(self.unsupervised_train_loader_iter) - weak_x = convert_tensor(unsup_batch["image"], self.device, non_blocking=True) - strong_x = convert_tensor(unsup_batch["strong_aug"], self.device, non_blocking=True) - - y_strong_preds = self.model(strong_x) - y_weak_preds = self.model(weak_x).detach() - y_pseudo = y_weak_preds.argmax(dim=1) - y_weak_probas = torch.softmax(y_weak_preds, dim=1) - max_y_weak_probas, _ = y_weak_probas.max(dim=1) - unsup_loss_mask = (max_y_weak_probas > self.confidence_threshold).float() - unsup_loss = (self.unsup_criterion(y_strong_preds, y_pseudo) * unsup_loss_mask).mean() - if self.config["with_nv_amp_level"] is not None: - from apex import amp - with amp.scale_loss(unsup_loss, self.optimizer) as scaled_loss: - scaled_loss.backward() - else: - unsup_loss.backward() - total_loss += self.lambda_u * unsup_loss - - self.optimizer.step() - - return { - "total_loss": total_loss, - "sup_loss": sup_loss, - "unsup_loss": unsup_loss, - "mask": unsup_loss_mask.mean() - } - - -if __name__ == "__main__": - main(FixMatchTwoStepsTrainer(), get_config()) diff --git 
a/main_fully_supervised.py b/main_fully_supervised.py index 276208b..4ca0348 100644 --- a/main_fully_supervised.py +++ b/main_fully_supervised.py @@ -1,39 +1,83 @@ -from base_train import main, BaseTrainer -from configs import get_default_config -import dist_utils +import hydra +from omegaconf import DictConfig +import ignite.distributed as idist +from ignite.utils import manual_seed, setup_logger -class FullySupervisedTrainer(BaseTrainer): +import utils +import trainers - output_names = ["sup_loss", ] - def train_step(self, engine, batch): - self.model.train() - self.optimizer.zero_grad() +def training(local_rank, cfg): - x, y = batch["sup_batch"] + logger = setup_logger( + "Fully-Supervised Training", distributed_rank=idist.get_rank() + ) - y_pred = self.model(x) + if local_rank == 0: + logger.info(cfg.pretty()) - # supervised learning: - sup_loss = self.sup_criterion(y_pred, y) + rank = idist.get_rank() + manual_seed(cfg.seed + rank) + device = idist.device() - if self.config["with_nv_amp_level"] is not None: - from apex import amp - with amp.scale_loss(sup_loss, self.optimizer) as scaled_loss: - scaled_loss.backward() - else: - sup_loss.backward() + model, ema_model, optimizer, sup_criterion, lr_scheduler = utils.initialize(cfg) - if dist_utils.is_tpu_distributed(): - dist_utils.xm.optimizer_step(self.optimizer) - else: - self.optimizer.step() + supervised_train_loader, test_loader, *_ = utils.get_dataflow(cfg) + + def train_step(engine, batch): + model.train() + optimizer.zero_grad() + + x = batch["sup_batch"]["image"] + y = batch["sup_batch"]["target"] + if x.device != device: + x = x.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + + y_pred = model(x) + sup_loss = sup_criterion(y_pred, y) + sup_loss.backward() + + optimizer.step() return { "sup_loss": sup_loss.item(), } + trainer = trainers.create_trainer( + train_step, + output_names=["sup_loss",], + model=model, + ema_model=ema_model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + supervised_train_loader=supervised_train_loader, + test_loader=test_loader, + cfg=cfg, + logger=logger, + ) + + epoch_length = cfg.solver.epoch_length + num_epochs = cfg.solver.num_epochs if not cfg.debug else 2 + try: + trainer.run( + supervised_train_loader, epoch_length=epoch_length, max_epochs=num_epochs + ) + except Exception as e: + import traceback + + print(traceback.format_exc()) + + +@hydra.main(config_path="config", config_name="fully_supervised") +def main(cfg: DictConfig) -> None: + + with idist.Parallel( + backend=cfg.distributed.backend, nproc_per_node=cfg.distributed.nproc_per_node + ) as parallel: + parallel.run(training, cfg) + if __name__ == "__main__": - main(FullySupervisedTrainer(), get_default_config()) + main() diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..4f52c9e --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,29 @@ +import torch.nn as nn + +from torchvision import models as tv_models + +from models.wrn import WideResNet + + +def setup_ema(ema_model, ref_model): + ema_model.load_state_dict(ref_model.state_dict()) + for param in ema_model.parameters(): + param.detach_() + # set EMA model's BN buffers as base model BN buffers: + for m1, m2 in zip(ref_model.modules(), ema_model.modules()): + if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d): + m2.running_mean = m1.running_mean + m2.running_var = m1.running_var + + +def setup_model(name, num_classes): + if name == "WRN-28-2": + model = WideResNet(num_classes=num_classes) + else: + if name in 
tv_models.__dict__: + fn = tv_models.__dict__[name] + else: + raise RuntimeError("Unknown model name {}".format(name)) + model = fn(num_classes=num_classes) + + return model diff --git a/wrn.py b/models/wrn.py similarity index 79% rename from wrn.py rename to models/wrn.py index 6afba1d..7ce1fe5 100644 --- a/wrn.py +++ b/models/wrn.py @@ -6,20 +6,27 @@ class Residual(nn.Module): - - def __init__(self, in_channels, out_channels, stride, activate_before_residual=False): + def __init__( + self, in_channels, out_channels, stride, activate_before_residual=False + ): super().__init__() self.bn1 = nn.BatchNorm2d(in_channels, momentum=0.001) self.leaky_relu = nn.LeakyReLU(0.1, inplace=True) - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=stride, padding=1 + ) self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.001) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) self.activate_before_residual = activate_before_residual if in_channels == out_channels: self.skip = nn.Identity() else: - self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0) + self.skip = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=stride, padding=0 + ) def forward(self, x0): x = self.leaky_relu(self.bn1(x0)) @@ -32,7 +39,6 @@ def forward(self, x0): class WideResNet(nn.Module): - def __init__(self, num_classes=10, filters=32, repeat=4): super().__init__() self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1) @@ -54,7 +60,7 @@ def __init__(self, num_classes=10, filters=32, repeat=4): self.bn = nn.BatchNorm2d(filters * 4, momentum=0.001) self.reduce = nn.AdaptiveAvgPool2d(1) self.fc = nn.Linear(filters * 4, num_classes) - + self.init_weights() def forward(self, x): @@ -71,7 +77,9 @@ def init_weights(self): if isinstance(m, nn.Conv2d): nn.init.normal_( m.weight, - std=torch.tensor(0.5 * m.kernel_size[0] * m.kernel_size[0] * m.out_channels).rsqrt() + std=torch.tensor( + 0.5 * m.kernel_size[0] * m.kernel_size[0] * m.out_channels + ).rsqrt(), ) elif isinstance(m, nn.Linear): nn.init.xavier_normal_(m.weight) diff --git a/tests/test_ops.py b/tests/test_ops.py index b1993e8..c56d0a2 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -1,4 +1,3 @@ -import pytest import numpy as np import torch @@ -8,21 +7,17 @@ def np_interleave(x, batch, inverse=False): # Code from utils.interleave with tf -> np - # + # # def interleave(x, batch): # s = x.get_shape().as_list() - # return tf.reshape(tf.transpose(tf.reshape(x, [-1, batch] + s[1:]), [1, 0] + list(range(2, 1+len(s)))), [-1] + s[1:]) + # return tf.reshape( + # tf.transpose(tf.reshape(x, [-1, batch] + s[1:]), [1, 0] + list(range(2, 1+len(s)))), [-1] + s[1:] + # ) s = list(x.shape) axes = [-1, batch] if not inverse else [batch, -1] return np.reshape( - np.transpose( - np.reshape( - x, - axes + s[1:] - ), - [1, 0] + list(range(2, 1 + len(s))) - ), - [-1] + s[1:] + np.transpose(np.reshape(x, axes + s[1:]), [1, 0] + list(range(2, 1 + len(s)))), + [-1] + s[1:], ) diff --git a/trainers/__init__.py b/trainers/__init__.py new file mode 100644 index 0000000..8c5b2e1 --- /dev/null +++ b/trainers/__init__.py @@ -0,0 +1,2 @@ + +from trainers.basic import create_trainer diff --git a/trainers/basic.py b/trainers/basic.py new file mode 100644 index 0000000..75b65d2 --- /dev/null +++ b/trainers/basic.py 
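
The trainer defined below receives both `model` and `ema_model`. A minimal sketch of how the `setup_model`/`setup_ema` helpers from `models/__init__.py` (earlier in this diff) might produce that pair is shown here; it is illustrative only (the repository presumably wires this up inside `utils.initialize`, whose new version is not shown in full in this diff) and assumes the repository root is importable and `pytorch-ignite` is installed.

```python
import ignite.distributed as idist

from models import setup_ema, setup_model

# Build the trainable model and a second copy that will track EMA weights;
# move both to the current device before sharing any buffers.
model = setup_model("WRN-28-2", num_classes=10).to(idist.device())
ema_model = setup_model("WRN-28-2", num_classes=10).to(idist.device())

# Copy the weights into the EMA model, detach its parameters from autograd,
# and make it reuse the reference model's BatchNorm running statistics.
setup_ema(ema_model, model)
```
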
@@ -0,0 +1,255 @@ +import os +from pathlib import Path + +import torch +import torch.nn as nn + +import ignite.distributed as idist +from ignite.contrib.engines import common +from ignite.contrib.handlers import ProgressBar +from ignite.engine import Engine, Events, create_supervised_evaluator +from ignite.metrics import Accuracy, Precision, Recall +from ignite.handlers import Checkpoint + +from dataflow import sup_prepare_batch, cycle +from ctaugment import stats, deserialize + + +def to_list_str(v): + if isinstance(v, torch.Tensor): + return " ".join(["%.2f" % i for i in v.tolist()]) + return "%.2f" % v + + +def create_trainer( + train_step, + output_names, + model, + ema_model, + optimizer, + lr_scheduler, + supervised_train_loader, + test_loader, + cfg, + logger, + cta=None, + unsup_train_loader=None, + cta_probe_loader=None, +): + + trainer = Engine(train_step) + trainer.logger = logger + + output_path = os.getcwd() + + to_save = { + "model": model, + "ema_model": ema_model, + "optimizer": optimizer, + "trainer": trainer, + "lr_scheduler": lr_scheduler, + } + if cta is not None: + to_save["cta"] = cta + + common.setup_common_training_handlers( + trainer, + train_sampler=supervised_train_loader.sampler, + to_save=to_save, + save_every_iters=cfg.solver.checkpoint_every, + output_path=output_path, + output_names=output_names, + lr_scheduler=lr_scheduler, + with_pbars=False, + clear_cuda_cache=False, + ) + + ProgressBar(persist=False).attach( + trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED + ) + + unsupervised_train_loader_iter = None + if unsup_train_loader is not None: + unsupervised_train_loader_iter = cycle(unsup_train_loader) + + cta_probe_loader_iter = None + if cta_probe_loader is not None: + cta_probe_loader_iter = cycle(cta_probe_loader) + + # Setup handler to prepare data batches + @trainer.on(Events.ITERATION_STARTED) + def prepare_batch(e): + sup_batch = e.state.batch + e.state.batch = { + "sup_batch": sup_batch, + } + if unsupervised_train_loader_iter is not None: + unsup_batch = next(unsupervised_train_loader_iter) + e.state.batch["unsup_batch"] = unsup_batch + + if cta_probe_loader_iter is not None: + cta_probe_batch = next(cta_probe_loader_iter) + cta_probe_batch["policy"] = [ + deserialize(p) for p in cta_probe_batch["policy"] + ] + e.state.batch["cta_probe_batch"] = cta_probe_batch + + # Setup handler to update EMA model + @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay) + def update_ema_model(ema_decay): + # EMA on parametes + for ema_param, param in zip(ema_model.parameters(), model.parameters()): + ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay) + + # Setup handlers for debugging + if cfg.debug: + + @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100)) + @idist.one_rank_only() + def log_weights_norms(): + wn = [] + ema_wn = [] + for ema_param, param in zip(ema_model.parameters(), model.parameters()): + wn.append(torch.mean(param.data)) + ema_wn.append(torch.mean(ema_param.data)) + + msg = "\n\nWeights norms" + msg += "\n- Raw model: {}".format( + to_list_str(torch.tensor(wn[:10] + wn[-10:])) + ) + msg += "\n- EMA model: {}\n".format( + to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:])) + ) + logger.info(msg) + + rmn = [] + rvar = [] + ema_rmn = [] + ema_rvar = [] + for m1, m2 in zip(model.modules(), ema_model.modules()): + if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d): + rmn.append(torch.mean(m1.running_mean)) + rvar.append(torch.mean(m1.running_var)) + 
+                    ema_rmn.append(torch.mean(m2.running_mean))
+                    ema_rvar.append(torch.mean(m2.running_var))
+
+            msg = "\n\nBN buffers"
+            msg += "\n- Raw mean: {}".format(to_list_str(torch.tensor(rmn[:10])))
+            msg += "\n- Raw var: {}".format(to_list_str(torch.tensor(rvar[:10])))
+            msg += "\n- EMA mean: {}".format(to_list_str(torch.tensor(ema_rmn[:10])))
+            msg += "\n- EMA var: {}\n".format(to_list_str(torch.tensor(ema_rvar[:10])))
+            logger.info(msg)
+
+    # TODO: Need to inspect a bug
+    # if idist.get_rank() == 0:
+    #     from ignite.contrib.handlers import ProgressBar
+    #
+    #     profiler = BasicTimeProfiler()
+    #     profiler.attach(trainer)
+    #
+    #     @trainer.on(Events.ITERATION_COMPLETED(every=200))
+    #     def log_profiling(_):
+    #         results = profiler.get_results()
+    #         profiler.print_results(results)
+
+    # Setup validation engine
+    metrics = {
+        "accuracy": Accuracy(),
+    }
+
+    if not (idist.has_xla_support and idist.backend() == idist.xla.XLA_TPU):
+        metrics.update({
+            "precision": Precision(average=False),
+            "recall": Recall(average=False),
+        })
+
+    eval_kwargs = dict(
+        metrics=metrics,
+        prepare_batch=sup_prepare_batch,
+        device=idist.device(),
+        non_blocking=True,
+    )
+
+    evaluator = create_supervised_evaluator(model, **eval_kwargs)
+    ema_evaluator = create_supervised_evaluator(ema_model, **eval_kwargs)
+
+    def log_results(epoch, max_epochs, metrics, ema_metrics):
+        msg1 = "\n".join(
+            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in metrics.items()]
+        )
+        msg2 = "\n".join(
+            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in ema_metrics.items()]
+        )
+        logger.info(
+            "\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2)
+        )
+        if cta is not None:
+            logger.info("\n" + stats(cta))
+
+    @trainer.on(
+        Events.EPOCH_COMPLETED(every=cfg.solver.validate_every)
+        | Events.STARTED
+        | Events.COMPLETED
+    )
+    def run_evaluation():
+        evaluator.run(test_loader)
+        ema_evaluator.run(test_loader)
+        log_results(
+            trainer.state.epoch,
+            trainer.state.max_epochs,
+            evaluator.state.metrics,
+            ema_evaluator.state.metrics,
+        )
+
+    # setup TB logging
+    if idist.get_rank() == 0:
+        tb_logger = common.setup_tb_logging(
+            output_path,
+            trainer,
+            optimizers=optimizer,
+            evaluators={"validation": evaluator, "ema validation": ema_evaluator},
+            log_every_iters=15,
+        )
+        if cfg.online_exp_tracking.wandb:
+            from ignite.contrib.handlers import WandBLogger
+
+            wb_dir = Path("/tmp/output-fixmatch-wandb")
+            if not wb_dir.exists():
+                wb_dir.mkdir()
+
+            _ = WandBLogger(
+                project="fixmatch-pytorch",
+                name=cfg.name,
+                config=cfg,
+                sync_tensorboard=True,
+                dir=wb_dir.as_posix(),
+                reinit=True,
+            )
+
+    resume_from = cfg.solver.resume_from
+    if resume_from is not None:
+        resume_from = list(Path(resume_from).rglob("training_checkpoint*.pt*"))
+        if len(resume_from) > 0:
+            # get latest
+            checkpoint_fp = max(resume_from, key=lambda p: p.stat().st_mtime)
+            assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
+                checkpoint_fp.as_posix()
+            )
+            logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
+            checkpoint = torch.load(checkpoint_fp.as_posix())
+            Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)
+
+    @trainer.on(Events.COMPLETED)
+    def release_all_resources():
+        nonlocal unsupervised_train_loader_iter, cta_probe_loader_iter
+
+        if idist.get_rank() == 0:
+            tb_logger.close()
+
+        if unsupervised_train_loader_iter is not None:
+            unsupervised_train_loader_iter = None
+
+        if cta_probe_loader_iter is not None:
+            cta_probe_loader_iter = None
+
+    return trainer
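Note: the `update_ema_model` handler above keeps an exponential moving average of the raw model's weights (`ema_param = decay * ema_param + (1 - decay) * param`), and this EMA model is what gets evaluated separately. A toy, self-contained illustration of that update rule (small `nn.Linear` modules stand in for the actual WRN-28-2):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
ema_model = nn.Linear(4, 2)
ema_model.load_state_dict(model.state_dict())  # start EMA from the same weights

ema_decay = 0.999

with torch.no_grad():
    # pretend one optimization step moved the raw weights
    for p in model.parameters():
        p.add_(0.01)

    # same in-place update as in update_ema_model above
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)
```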
diff --git a/utils.py b/utils.py
index 252a400..9aef693 100644
--- a/utils.py
+++ b/utils.py
@@ -1,386 +1,88 @@
-from functools import partial
-import json
-import random
-from collections import OrderedDict
-
-import numpy as np
-
-import torch
 import torch.nn as nn
-import torch.optim as optim
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data import Subset, Dataset, DataLoader
-
-from torchvision import transforms as T
-from torchvision.datasets.cifar import CIFAR10
-from torchvision import models as tv_models
-
-from ignite.utils import convert_tensor
-
-from ctaugment import OPS, CTAugment, OP
-from wrn import WideResNet
-import dist_utils
-
-weak_transforms = T.Compose([
-    T.Pad(4),
-    T.RandomCrop(32),
-    T.RandomHorizontalFlip(),
-    T.ToTensor(),
-    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25))
-])
-
-test_transforms = T.Compose([
-    T.ToTensor(),
-    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25))
-])
-
-cutout_image_transforms = T.Compose([
-    T.ToTensor(),
-    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25)),
-    T.RandomErasing(scale=(0.02, 0.15))
-])
-
-
-def set_seed(seed):
-    random.seed(seed)
-    torch.manual_seed(seed)
-    np.random.seed(seed)
-
-
-class TransformedDataset(Dataset):
-
-    def __init__(self, dataset, transforms):
-        self.dataset = dataset
-        self.transforms = transforms
-
-    def __getitem__(self, i):
-        dp = self.dataset[i]
-        return self.transforms(dp)
-
-    def __len__(self):
-        return len(self.dataset)
-
-
-def get_supervised_trainset(root, num_train_samples_per_class=25, download=True):
-    num_classes = 10
-    full_train_dataset = CIFAR10(root, train=True, download=download)
-
-    supervised_train_indices = []
-    counter = [0] * num_classes
-
-    indices = list(range(len(full_train_dataset)))
-    random_indices = np.random.permutation(indices)
-
-    for i in random_indices:
-        dp = full_train_dataset[i]
-        if len(supervised_train_indices) >= num_classes * num_train_samples_per_class:
-            break
-        if counter[dp[1]] < num_train_samples_per_class:
-            counter[dp[1]] += 1
-            supervised_train_indices.append(i)
-
-    return Subset(full_train_dataset, supervised_train_indices)
-
-
-def get_supervised_trainset_0_250(root, download=True):
-    full_train_dataset = CIFAR10(root, train=True, download=download)
-
-    supervised_train_indices = [
-        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
-        16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
-        30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43,
-        44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
-        58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
-        72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85,
-        86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99,
-        100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110,
-        111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,
-        122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132,
-        133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
-        144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154,
-        155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
-        166, 167, 169, 170, 171, 172, 173, 174, 175, 177, 178,
-        179, 180, 181, 182, 183, 185, 186, 187, 188, 189, 190,
-        191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 202,
-        203, 204, 205, 207, 209, 210, 211, 213, 215, 216, 217,
-        218, 220, 221, 222, 223, 224, 228, 229, 230, 231, 233,
-        237, 239, 240, 241, 244, 246, 247, 252, 254, 256, 259,
-        260, 263, 264, 268, 271, 272, 276, 277, 279, 280, 281,
-        284, 285, 290, 293, 296, 308, 317
-    ]
-    return Subset(full_train_dataset, supervised_train_indices)
-
-
-def get_supervised_train_loader(supervised_train_dataset, transforms=weak_transforms, **dataloader_kwargs):
-    dataloader_kwargs['pin_memory'] = True
-    dataloader_kwargs['drop_last'] = True
-    dataloader_kwargs['shuffle'] = dataloader_kwargs.get("sampler", None) is None
-
-    supervised_train_loader = DataLoader(
-        TransformedDataset(
-            supervised_train_dataset,
-            transforms=lambda d: {"image": transforms(d[0]), "target": d[1]}
-        ),
-        **dataloader_kwargs
-    )
-    return supervised_train_loader
-
-
-def get_test_loader(root, transforms=test_transforms, **dataloader_kwargs):
-    full_test_dataset = CIFAR10(root, train=False, download=False)
-
-    dataloader_kwargs['pin_memory'] = True
-    dataloader_kwargs['drop_last'] = False
-    dataloader_kwargs['shuffle'] = False
-
-    if dist_utils.is_tpu_distributed():
-        dataloader_kwargs["num_workers"] = 1
-        dataloader_kwargs["sampler"] = DistributedSampler(
-            full_test_dataset,
-            num_replicas=dist_utils.get_world_size(),
-            rank=dist_utils.get_rank()
-        )
-
-    test_loader = DataLoader(
-        TransformedDataset(
-            full_test_dataset,
-            transforms=lambda dp: {"image": transforms(dp[0]), "target": dp[1]}
-        ),
-        **dataloader_kwargs
-    )
-
-    return test_loader
-
-
-class StorableCTAugment(CTAugment):
-
-    def load_state_dict(self, state):
-        for k in ["decay", "depth", "th", "rates"]:
-            assert k in state, "{} not in {}".format(k, state.keys())
-            setattr(self, k, state[k])
-
-    def state_dict(self):
-        return OrderedDict([(k, getattr(self, k)) for k in ["decay", "depth", "th", "rates"]])
-
-
-def get_default_cta():
-    return StorableCTAugment()
+import ignite.distributed as idist

-def cta_apply(pil_img, ops):
-    if ops is None:
-        return pil_img
-    for op, args in ops:
-        pil_img = OPS[op].f(pil_img, *args)
-    return pil_img
+from hydra.utils import instantiate

+from models import setup_model, setup_ema
+from dataflow import (
+    get_supervised_train_loader,
+    get_test_loader,
+    get_unsupervised_train_loader,
+    get_cta_probe_loader,
+)

-def deserialize(policy_str):
-    return [OP(f=x[0], bins=x[1]) for x in json.loads(policy_str)]

+def initialize(cfg):
+    model = setup_model(cfg.model, num_classes=cfg.num_classes)
+    ema_model = setup_model(cfg.model, num_classes=cfg.num_classes)

-def cta_image_transforms(pil_img, cta, transform=cutout_image_transforms):
-    policy = cta.policy(probe=False)
-    pil_img = cta_apply(pil_img, policy)
-    return transform(pil_img)
-
-
-def cta_probe_transforms(dp, cta, image_transforms=cutout_image_transforms):
-    policy = cta.policy(probe=True)
-    probe = cta_apply(dp[0], policy)
-    probe = image_transforms(probe)
-    return {
-        "image": probe,
-        "target": dp[1],
-        "policy": json.dumps(policy)
-    }
-
-
-def get_cta_probe_loader(supervised_train_dataset, cta, **dataloader_kwargs):
-    dataloader_kwargs['pin_memory'] = True
-    dataloader_kwargs['drop_last'] = False
-    dataloader_kwargs['shuffle'] = dataloader_kwargs.get("sampler", None) is None
-
-    cta_probe_loader = DataLoader(
-        TransformedDataset(
-            supervised_train_dataset,
-            transforms=partial(cta_probe_transforms, cta=cta)
-        ),
-        **dataloader_kwargs
-    )
+    model.to(idist.device())
+    ema_model.to(idist.device())
+    setup_ema(ema_model, model)

-    return cta_probe_loader
+    model = idist.auto_model(model)
+    if isinstance(model, nn.parallel.DataParallel):
+        ema_model = nn.parallel.DataParallel(ema_model)

-
-def get_unsupervised_train_loader(raw_dataset, transforms_weak, transforms_strong, **dataloader_kwargs):
-    unsupervised_train_dataset = TransformedDataset(
-        raw_dataset,
-        transforms=lambda dp: {"image": transforms_weak(dp[0]), "strong_aug": transforms_strong(dp[0])}
-    )
+    optimizer = instantiate(cfg.solver.optimizer, model.parameters())
+    optimizer = idist.auto_optim(optimizer)

-    dataloader_kwargs['pin_memory'] = True
-    dataloader_kwargs['drop_last'] = True
-    dataloader_kwargs['shuffle'] = dataloader_kwargs.get("sampler", None) is None
+    sup_criterion = instantiate(cfg.solver.supervised_criterion)

-    unsupervised_train_loader = DataLoader(
-        unsupervised_train_dataset,
-        **dataloader_kwargs
+    total_num_iters = cfg.solver.num_epochs * cfg.solver.epoch_length
+    lr_scheduler = instantiate(
+        cfg.solver.lr_scheduler, optimizer, T_max=total_num_iters
     )
-    return unsupervised_train_loader
-
-
-def sup_prepare_batch(batch, device, non_blocking):
-    x = convert_tensor(batch["image"], device, non_blocking)
-    y = convert_tensor(batch["target"], device, non_blocking)
-    return x, y
-
-
-def cycle(dataloader):
-    while True:
-        for b in dataloader:
-            yield b
-
-
-def stats(cta):
-    return '\n'.join('%-16s %s' % (k, ' / '.join(' '.join('%.2f' % x for x in cta.rate_to_p(rate))
-                                                 for rate in cta.rates[k]))
-                     for k in sorted(OPS.keys()))
-
-
-def interleave(x, batch, inverse=False):
-    """
-    TF code
-    def interleave(x, batch):
-        s = x.get_shape().as_list()
-        return tf.reshape(tf.transpose(tf.reshape(x, [-1, batch] + s[1:]), [1, 0] + list(range(2, 1+len(s)))), [-1] + s[1:])
-    """
-    shape = x.shape
-    axes = [batch, -1] if inverse else [-1, batch]
-    return x.reshape(*axes, *shape[1:]).transpose(0, 1).reshape(-1, *shape[1:])
-
-
-def deinterleave(x, batch):
-    return interleave(x, batch, inverse=True)
-
-
-def setup_ema(ema_model, ref_model):
-    ema_model.load_state_dict(ref_model.state_dict())
-    for param in ema_model.parameters():
-        param.detach_()
-    # set EMA model's BN buffers as base model BN buffers:
-    for m1, m2 in zip(ref_model.modules(), ema_model.modules()):
-        if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d):
-            m2.running_mean = m1.running_mean
-            m2.running_var = m1.running_var
+    return model, ema_model, optimizer, sup_criterion, lr_scheduler


-def to_list_str(v):
-    if isinstance(v, torch.Tensor):
-        return " ".join(["%.2f" % i for i in v.tolist()])
-    return "%.2f" % v
+def get_dataflow(cfg, cta=None, with_unsup=False):

-
-def get_dataflow_iters(config, cta, distributed=False):
-    batch_size = config["batch_size"]
-    num_workers = config["num_workers"]
-
-    # Rescale batch_size and num_workers
-    batch_size //= dist_utils.get_world_size()
-
-    if distributed:
-        nproc_per_node = dist_utils.get_num_proc_per_node()
-        num_workers = int((num_workers + nproc_per_node - 1) / nproc_per_node)
-
-    num_workers //= 3  # 3 dataloaders
-
-    # Setup dataflow
-    if config["num_train_samples_per_class"] == 25:
-        supervised_train_dataset = get_supervised_trainset_0_250(config["data_path"])
-    else:
-        supervised_train_dataset = get_supervised_trainset(
-            config["data_path"],
-            config["num_train_samples_per_class"]
-        )
-
-    dist_sampler_kwargs = {"num_replicas": dist_utils.get_world_size(), "rank": dist_utils.get_rank()}
-
-    supervised_train_loader = get_supervised_train_loader(
-        supervised_train_dataset,
-        transforms=weak_transforms,
-        batch_size=batch_size,
-        num_workers=num_workers,
-        sampler=DistributedSampler(supervised_train_dataset, **dist_sampler_kwargs) if distributed else None
+    num_workers = (
+        cfg.dataflow.num_workers if cta is None else cfg.dataflow.num_workers // 2
     )

-    cta_probe_loader = get_cta_probe_loader(
-        supervised_train_dataset,
-        cta=cta,
-        batch_size=batch_size,
+    sup_train_loader = get_supervised_train_loader(
+        cfg.dataflow.name,
+        root=cfg.dataflow.data_path,
+        num_train_samples_per_class=cfg.ssl.num_train_samples_per_class,
+        batch_size=cfg.dataflow.batch_size,
         num_workers=num_workers,
-        sampler=DistributedSampler(supervised_train_dataset, **dist_sampler_kwargs) if distributed else None
     )

-    full_train_dataset = CIFAR10(config["data_path"], train=True)
-    unsupervised_train_loader = get_unsupervised_train_loader(
-        full_train_dataset,
-        transforms_weak=weak_transforms,
-        transforms_strong=partial(cta_image_transforms, cta=cta),
-        batch_size=batch_size * config["mu_ratio"],
-        num_workers=num_workers,
-        sampler=DistributedSampler(full_train_dataset, **dist_sampler_kwargs) if distributed else None
+    test_loader = get_test_loader(
+        cfg.dataflow.name,
+        root=cfg.dataflow.data_path,
+        batch_size=cfg.dataflow.batch_size,
+        num_workers=cfg.dataflow.num_workers,
     )

-    # Setup training/validation loops
-    supervised_train_loader_iter = cycle(supervised_train_loader)
-    unsupervised_train_loader_iter = cycle(unsupervised_train_loader)
-    cta_probe_loader_iter = cycle(cta_probe_loader)
-
-    return supervised_train_loader_iter, unsupervised_train_loader_iter, cta_probe_loader_iter
-
-
-def get_models_optimizer(config, distributed=False):
-    device = config["device"]
-    if device == "xla":
-        import torch_xla.core.xla_model as xm
-        device = xm.xla_device()
-
-    # Setup model, optimizer and setup EMA model
-    if config["model"] == "WRN-28-2":
-        model = WideResNet(num_classes=10)
-        ema_model = WideResNet(num_classes=10)
-    else:
-        name = config["model"]
-        if name in tv_models.__dict__:
-            fn = tv_models.__dict__[name]
-        else:
-            raise RuntimeError("Unknown model name {}".format(name))
-        model = fn(num_classes=10)
-        ema_model = fn(num_classes=10)
-
-    model.to(device)
-    ema_model.to(device)
-    setup_ema(ema_model, model)
-
-    optimizer = optim.SGD(
-        model.parameters(),
-        lr=config["learning_rate"],
-        momentum=config["momentum"],
-        weight_decay=config["weight_decay"],
-        nesterov=True,
-    )
-
-    if device == "cuda" and (config["with_nv_amp_level"] is not None):
-        assert config["with_nv_amp_level"] in ("O1", "O2")
-        from apex import amp
-        models, optimizer = amp.initialize([model, ema_model], optimizer, opt_level=config["with_nv_amp_level"])
-        model, ema_model = models
+    unsup_train_loader = None
+    if with_unsup:
+        if cta is None:
+            raise ValueError(
+                "If with_unsup=True, cta should be defined, but given None"
+            )
+        unsup_train_loader = get_unsupervised_train_loader(
+            cfg.dataflow.name,
+            root=cfg.dataflow.data_path,
+            cta=cta,
+            batch_size=int(cfg.dataflow.batch_size * cfg.ssl.mu_ratio),
+            num_workers=num_workers,
+        )

-    if distributed and device == "cuda":
-        model = DDP(model, device_ids=[config["local_rank"], ])
-        # ema_model has no grads => DDP wont work
-    elif config["with_nv_amp_level"] in ("O1", None) and device == "cuda" and torch.cuda.device_count() > 0:
-        model = nn.parallel.DataParallel(model)
-        ema_model = nn.parallel.DataParallel(ema_model)
+    cta_probe_loader = None
+    if cta is not None:
+        cta_probe_loader = get_cta_probe_loader(
+            cfg.dataflow.name,
+            root=cfg.dataflow.data_path,
+            num_train_samples_per_class=cfg.ssl.num_train_samples_per_class,
+            cta=cta,
+            batch_size=int(cfg.dataflow.batch_size * cfg.ssl.mu_ratio),
+            num_workers=num_workers,
+        )

-    return model, ema_model, optimizer
+    return sup_train_loader, test_loader, unsup_train_loader, cta_probe_loader
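Note: in the new `utils.py`, `initialize` builds the optimizer, supervised criterion and LR scheduler via `hydra.utils.instantiate` from the Hydra config. The sketch below shows how such a call is expected to resolve; the config keys (`_target_`, `lr`, ...) and values are assumptions for illustration only, since the repository's actual config files are not part of this diff:

```python
import torch.nn as nn
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical config node; the real values live in the repo's Hydra config files.
cfg = OmegaConf.create(
    {
        "solver": {
            "optimizer": {
                "_target_": "torch.optim.SGD",
                "lr": 0.03,
                "momentum": 0.9,
                "nesterov": True,
            }
        }
    }
)

model = nn.Linear(4, 2)  # stand-in for the real model
optimizer = instantiate(cfg.solver.optimizer, model.parameters())
print(type(optimizer).__name__)  # SGD
```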