README.md (82 changes: 28 additions & 54 deletions)

Experiments with "FixMatch" on the CIFAR-10 dataset.

Based on ["FixMatch: Simplifying Semi-Supervised Learning withConsistency and Confidence"](https://arxiv.org/abs/2001.07685)
and its official [code](https://github.com/google-research/fixmatch).
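
For quick reference, the core of the method is a confidence-thresholded pseudo-labeling loss on unlabeled images: pseudo-labels are taken from a weakly augmented view of each image, and the model is trained to reproduce them on a strongly augmented view. Below is a minimal sketch of that unlabeled loss term (illustrative only, not this repository's actual training step; `model`, `weak_batch`, `strong_batch`, `tau` and `lambda_u` are placeholder names):

```python
# Sketch of the FixMatch unlabeled loss (illustrative; not this repo's code).
import torch
import torch.nn.functional as F


def fixmatch_unlabeled_loss(model, weak_batch, strong_batch, tau=0.95, lambda_u=1.0):
    # Pseudo-labels come from the weakly augmented view, without gradient.
    with torch.no_grad():
        probs = torch.softmax(model(weak_batch), dim=1)
        max_probs, pseudo_labels = probs.max(dim=1)
        mask = (max_probs >= tau).float()  # keep only confident predictions
    # Cross-entropy of the strongly augmented view against those pseudo-labels,
    # masked by the confidence threshold and averaged over the whole batch.
    logits_strong = model(strong_batch)
    per_sample = F.cross_entropy(logits_strong, pseudo_labels, reduction="none")
    return lambda_u * (per_sample * mask).mean()
```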

## Requirements

```bash
pip install --upgrade --pre hydra-core tensorboardX
pip install --upgrade --pre pytorch-ignite
```

## Training

```bash
python -u main_fixmatch.py
# or python -u main_fixmatch.py --params "data_path=/path/to/cifar10"
python -u main_fixmatch.py model=WRN-28-2
```

To specify the input/output folders:
```bash
python -u main_fixmatch.py dataflow.data_path=/data/cifar10/ hydra.run.dir=/output-fixmatch model=WRN-28-2
```

To use the Weights & Biases (wandb) logger, log in first and run with `online_exp_tracking.wandb=true`:
```bash
wandb login <token>
python -u main_fixmatch.py model=WRN-28-2 online_exp_tracking.wandb=true
```

To see other options:
```bash
VERSION = "1.5"
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION
python -u main_fixmatch.py --help
```

### Training curves visualization

By default, we use Tensorboard to log training curves:

```bash
python -u main_fixmatch.py --params="device='xla'"
tensorboard --logdir=/tmp/output-fixmatch-cifar10-hydra/
```

### Distributed Data Parallel (DDP) on multiple GPUs (Experimental)

For example, training on 2 GPUs:
```bash
python -u main_fixmatch.py --params="device='xla';distributed=True"
python -u -m torch.distributed.launch --nproc_per_node=2 main_fixmatch.py model=WRN-28-2 distributed.backend=nccl
```

### TPU(s) on Colab (Experimental) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ZoWz1-a3bpj1xMxpM2K2qQ4Y9xvtdGWO)

For example, training on 8 TPUs in distributed mode:
```bash
python -u main_fixmatch.py model=resnet18 distributed.backend=xla-tpu distributed.nproc_per_node=8
# or python -u main_fixmatch.py model=WRN-28-2 distributed.backend=xla-tpu distributed.nproc_per_node=8
```

TODO (38 changes: 38 additions & 0 deletions)

## TODO

* [x] Resume training from existing checkpoint (see the checkpointing sketch at the end of this file):
* [x] save/load CTA
* [x] save ema model

* [ ] DDP:
* [x] Synchronize CTA across processes
* [ ] Bug: DDP performance is worse than DP during the first epochs

* [x] Logging to an online platform: W&B

* [ ] Replace PIL augmentations with Albumentations

```python
import random

import numpy as np
from albumentations.augmentations import functional as F  # F.blur; import path may differ across albumentations versions
from albumentations.core.transforms_interface import ImageOnlyTransform


class BlurLimitSampler:
    """Samples a blur kernel size from a fixed set with the given probabilities."""

    def __init__(self, blur, weights):
        self.blur = blur  # e.g. [3, 5, 7]
        self.weights = weights  # e.g. [0.1, 0.5, 0.4]

    def get_params(self):
        # np.random.choice is required here: random.choice does not accept weights.
        return {"ksize": int(np.random.choice(self.blur, p=self.weights))}


class Blur(ImageOnlyTransform):
    """Blur transform whose kernel size comes from a (min, max) range or a BlurLimitSampler."""

    def __init__(self, blur_limit, always_apply=False, p=0.5):
        super(Blur, self).__init__(always_apply, p)
        self.blur_limit = blur_limit

    def apply(self, image, ksize=3, **params):
        return F.blur(image, ksize)

    def get_params(self):
        if isinstance(self.blur_limit, BlurLimitSampler):
            return self.blur_limit.get_params()
        # Otherwise sample an odd kernel size in [blur_limit[0], blur_limit[1]].
        return {"ksize": int(random.choice(np.arange(self.blur_limit[0], self.blur_limit[1] + 1, 2)))}

    def get_transform_init_args_names(self):
        return ("blur_limit",)
```
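
As a companion note to the resume item above: a rough sketch of how the training state (including the EMA model and the CTA policy) could be saved and restored with ignite's `Checkpoint` handler. This is an illustration under assumptions, not the repository's actual code; `model`, `ema_model`, `optimizer` and `cta` are placeholder objects assumed to expose `state_dict()`/`load_state_dict()`.

```python
# Illustrative sketch, not this repository's actual checkpointing code.
import torch
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver


def attach_checkpointing(trainer: Engine, model, ema_model, optimizer, cta, output_dir):
    # Everything listed here is saved together and restored together.
    to_save = {
        "trainer": trainer,
        "model": model,
        "ema_model": ema_model,
        "optimizer": optimizer,
        "cta": cta,  # assumes the CTA object implements state_dict()/load_state_dict()
    }
    handler = Checkpoint(to_save, DiskSaver(output_dir, require_empty=False), n_saved=2)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
    return to_save


def resume_from(to_load, checkpoint_path):
    # Restore all objects from a previously saved checkpoint file.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
```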