|
3 | 3 | import torch
|
4 | 4 | import torchvision
|
5 | 5 | import torch.nn as nn
|
6 |
| -import LIVEFolder |
7 | 6 | from SCNN import SCNN
|
8 | 7 | from PIL import Image
|
9 | 8 | from scipy import stats
|
@@ -163,15 +162,24 @@ def __init__(self, options, path):
|
163 | 162 | ])
|
164 | 163 |
|
165 | 164 |
|
166 |
| - if self._options['dataset'] == 'live': |
| 165 | + if self._options['dataset'] == 'live': |
| 166 | + import LIVEFolder |
167 | 167 | train_data = LIVEFolder.LIVEFolder(
|
168 | 168 | root=self._path['live'], loader = default_loader, index = self._options['train_index'],
|
169 | 169 | transform=train_transforms)
|
170 | 170 | test_data = LIVEFolder.LIVEFolder(
|
171 | 171 | root=self._path['live'], loader = default_loader, index = self._options['test_index'],
|
172 | 172 | transform=test_transforms)
|
| 173 | + elif self._options['dataset'] == 'livec': |
| 174 | + import LIVEChallengeFolder |
| 175 | + train_data = LIVEChallengeFolder.LIVEChallengeFolder( |
| 176 | + root=self._path['livec'], loader = default_loader, index = self._options['train_index'], |
| 177 | + transform=train_transforms) |
| 178 | + test_data = LIVEChallengeFolder.LIVEChallengeFolder( |
| 179 | + root=self._path['livec'], loader = default_loader, index = self._options['test_index'], |
| 180 | + transform=test_transforms) |
173 | 181 | else:
|
174 |
| - raise AttributeError('Only support LIVE right now!') |
| 182 | + raise AttributeError('Only support LIVE and LIVEC right now!') |
175 | 183 | self._train_loader = torch.utils.data.DataLoader(
|
176 | 184 | train_data, batch_size=self._options['batch_size'],
|
177 | 185 | shuffle=True, num_workers=0, pin_memory=True)
|
|
0 commit comments