Skip to content

Commit 5694814

Browse files
committed
config
1 parent f747654 commit 5694814

File tree

7 files changed

+25
-15
lines changed

7 files changed

+25
-15
lines changed

config.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import models.darts.genotypes as gt
55
from functools import partial
66
import torch
7+
import time
78

89

910
def get_parser(name):
@@ -61,20 +62,23 @@ def build_parser(self):
6162
parser.add_argument('--layers', type=int, default=20, help='# of layers')
6263
parser.add_argument('--seed', type=int, default=2, help='random seed')
6364
parser.add_argument('--workers', type=int, default=4, help='# of workers')
64-
parser.add_argument('--aux_weight', type=float, default=0.4, help='auxiliary loss weight')
65+
parser.add_argument('--aux_weight', type=float, default=0, help='auxiliary loss weight')
6566
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
66-
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path prob')
67+
# parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path prob')
68+
parser.add_argument('--drop_path_prob', type=float, default=0, help='drop path prob')
6769

6870
parser.add_argument('--genotype', default='', help='Cell genotype')
71+
parser.add_argument('--deterministic', type=bool, default=True, help='momentum')
6972

7073
return parser
7174

7275
def __init__(self):
7376
parser = self.build_parser()
7477
args = parser.parse_args()
7578
super().__init__(**vars(args))
76-
77-
self.path = os.path.join('expreiments', self.model_method + '_' + self.model_name)
79+
time_str = time.asctime(time.localtime()).replace(' ', '_')
80+
self.path = os.path.join('/userhome/project/pytorch_image_classification/expreiments', self.model_method + '_'
81+
+ self.model_name + '_' + time_str)
7882
if len(self.genotype) > 1:
7983
self.genotype = gt.from_str(self.genotype)
8084
else:

models/darts/augment_cnn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
""" CNN for network augmentation """
22
import torch.nn as nn
3-
from models.darts.augment_cells import AugmentCell
4-
from models.darts import ops
3+
from .augment_cells import AugmentCell
4+
from . import ops
55

66

77
class AuxiliaryHead(nn.Module):

models/darts/genotypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections import namedtuple
77
import torch
88
import torch.nn as nn
9-
from models.darts import ops
9+
from . import ops
1010

1111

1212
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

models/darts/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
""" Operations """
22
import torch
33
import torch.nn as nn
4-
import models.darts.genotypes as gt
4+
from . import genotypes as gt
55

66

77
OPS = {

models/get_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@
123123
'DARTS_V1': DARTS_V1,
124124
'DARTS_V2': DARTS_V2}
125125

126-
Manual_model_dict = {'resnet18': resnet18.ResNet18()}
126+
Manual_model_dict = {'Resnet18': resnet18.ResNet18()}
127127

128128

129129
def get_model(method, name):
@@ -134,6 +134,6 @@ def get_model(method, name):
134134
else:
135135
raise NotImplementedError
136136
if name in model_dict:
137-
return model_dict[method]
137+
return model_dict[name]
138138
else:
139139
raise NotImplementedError

train.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
writer = SummaryWriter(log_dir=os.path.join(config.path, "tb"))
2020
writer.add_text('config', config.as_markdown(), 0)
2121

22-
logger = utils.get_logger(os.path.join(config.path, "{}.log".format(config.name)))
22+
logger = utils.get_logger(os.path.join(config.path, "logger.log"))
2323
config.print_params(logger.info)
2424

2525

@@ -34,7 +34,12 @@ def main():
3434
torch.manual_seed(config.seed)
3535
torch.cuda.manual_seed_all(config.seed)
3636

37-
torch.backends.cudnn.benchmark = True
37+
if config.deterministic:
38+
torch.backends.cudnn.benchmark = False
39+
torch.backends.cudnn.deterministic = True
40+
torch.backends.cudnn.enabled = True
41+
else:
42+
torch.backends.cudnn.benchmark = True
3843

3944
# get data with meta info
4045
input_size, input_channels, n_classes, train_data, valid_data = get_data.get_data(
@@ -77,8 +82,9 @@ def main():
7782
# training loop
7883
for epoch in range(config.epochs):
7984
lr_scheduler.step()
80-
drop_prob = config.drop_path_prob * epoch / config.epochs
81-
model.module.drop_path_prob(drop_prob)
85+
if config.drop_path_prob > 0:
86+
drop_prob = config.drop_path_prob * epoch / config.epochs
87+
model.module.drop_path_prob(drop_prob)
8288

8389
# training
8490
train(train_loader, model, optimizer, criterion, epoch)

utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def netParams(model):
7373
p *= parameter.size(j)
7474
total_paramters += p
7575

76-
return total_paramters
76+
return total_paramters / 1024. /1024.
7777

7878

7979
class AverageMeter(object):

0 commit comments

Comments
 (0)