1+ from PIL .Image import TRANSPOSE
2+
3+
4+ norm_cfg = dict (type = 'SyncBN' , requires_grad = True )
5+ num_classes = 3
6+ dataset_type = 'TDDataset'
7+ data_root = 'data/td/'
8+ img_norm_cfg = dict (
9+ mean = [123.675 , 116.28 , 103.53 ], std = [58.395 , 57.12 , 57.375 ], to_rgb = True )
10+ #crop_size = (640, 640)
11+ #img_scale = (2048, 640)
12+ img_scale = (960 , 540 )
13+ #img_scale = (640, 640)
14+ keep_ratio = True
15+ model = dict (
16+ type = 'EncoderDecoder' ,
17+ pretrained = 'pretrain/mit_b5.pth' ,
18+ backbone = dict (
19+ type = 'MixVisionTransformer' ,
20+ in_channels = 3 ,
21+ embed_dims = 64 ,
22+ num_stages = 4 ,
23+ num_layers = [3 , 6 , 40 , 3 ],
24+ num_heads = [1 , 2 , 5 , 8 ],
25+ patch_sizes = [7 , 3 , 3 , 3 ],
26+ sr_ratios = [8 , 4 , 2 , 1 ],
27+ out_indices = (0 , 1 , 2 , 3 ),
28+ mlp_ratio = 4 ,
29+ qkv_bias = True ,
30+ drop_rate = 0.0 ,
31+ attn_drop_rate = 0.0 ,
32+ drop_path_rate = 0.1 ),
33+ decode_head = dict (
34+ type = 'SegformerHead' ,
35+ in_channels = [64 , 128 , 320 , 512 ],
36+ in_index = [0 , 1 , 2 , 3 ],
37+ channels = 256 ,
38+ dropout_ratio = 0.1 ,
39+ num_classes = num_classes ,
40+ norm_cfg = dict (type = 'SyncBN' , requires_grad = True ),
41+ align_corners = False ,
42+ loss_decode = dict (
43+ type = 'CrossEntropyLoss' , use_sigmoid = False , loss_weight = 1.0 )),
44+ train_cfg = dict (),
45+ test_cfg = dict (mode = 'whole' ))
46+
47+ train_pipeline = [
48+ dict (type = 'LoadImageFromFile' ),
49+ dict (type = 'LoadAnnotations' ),
50+ dict (type = 'Resize' , img_scale = img_scale , keep_ratio = keep_ratio ),
51+ #dict(type='RandomCrop', crop_size=(640, 640), cat_max_ratio=0.75),
52+ dict (type = 'RandomFlip' , prob = 0.5 ),
53+ dict (type = 'PhotoMetricDistortion' ),
54+ dict (
55+ type = 'Normalize' ,
56+ mean = [123.675 , 116.28 , 103.53 ],
57+ std = [58.395 , 57.12 , 57.375 ],
58+ to_rgb = True ),
59+ #dict(type='Pad', size=(640, 640), pad_val=0, seg_pad_val=255),
60+ dict (type = 'DefaultFormatBundle' ),
61+ dict (type = 'Collect' , keys = ['img' , 'gt_semantic_seg' ])
62+ ]
63+ test_pipeline = [
64+ dict (type = 'LoadImageFromFile' ),
65+ dict (
66+ type = 'MultiScaleFlipAug' ,
67+ img_scale = img_scale ,
68+ flip = False ,
69+ transforms = [
70+ dict (type = 'Resize' , keep_ratio = keep_ratio ),
71+ dict (type = 'RandomFlip' ),
72+ dict (
73+ type = 'Normalize' ,
74+ mean = [123.675 , 116.28 , 103.53 ],
75+ std = [58.395 , 57.12 , 57.375 ],
76+ to_rgb = True ),
77+ dict (type = 'ImageToTensor' , keys = ['img' ]),
78+ dict (type = 'Collect' , keys = ['img' ])
79+ ])
80+ ]
81+ data = dict (
82+ samples_per_gpu = 1 ,
83+ workers_per_gpu = 1 ,
84+ train = dict (
85+ type = dataset_type ,
86+ data_root = data_root ,
87+ img_dir = 'images' ,
88+ ann_dir = 'annotations/train_AW.json' ,
89+ pipeline = train_pipeline ),
90+ val = dict (
91+ type = dataset_type ,
92+ data_root = data_root ,
93+ img_dir = 'images' ,
94+ ann_dir = 'annotations/test_AW.json' ,
95+ pipeline = test_pipeline ),
96+ test = dict (
97+ type = dataset_type ,
98+ data_root = data_root ,
99+ img_dir = 'images' ,
100+ ann_dir = 'annotations/test_AW.json' ,
101+ pipeline = test_pipeline ))
102+ log_config = dict (
103+ interval = 50 , hooks = [dict (type = 'TextLoggerHook' , by_epoch = False )])
104+ dist_params = dict (backend = 'nccl' )
105+ log_level = 'INFO'
106+ load_from = None
107+ resume_from = None
108+ workflow = [('train' , 1 )]
109+ cudnn_benchmark = True
110+ optimizer = dict (
111+ type = 'AdamW' ,
112+ lr = 6e-06 ,
113+ betas = (0.9 , 0.999 ),
114+ weight_decay = 0.01 ,
115+ paramwise_cfg = dict (
116+ custom_keys = dict (
117+ pos_block = dict (decay_mult = 0.0 ),
118+ norm = dict (decay_mult = 0.0 ),
119+ head = dict (lr_mult = 10.0 ))))
120+ optimizer_config = dict ()
121+ lr_config = dict (
122+ policy = 'poly' ,
123+ warmup = 'linear' ,
124+ warmup_iters = 3200 ,
125+ warmup_ratio = 1e-06 ,
126+ power = 1.0 ,
127+ min_lr = 0.0 ,
128+ by_epoch = False )
129+ runner = dict (type = 'IterBasedRunner' , max_iters = 40000 )
130+ checkpoint_config = dict (by_epoch = False , interval = 4000 )
131+ evaluation = dict (interval = 4000 , metric = 'mIoU' , pre_eval = True )
132+ work_dir = data_root + 'work_dirs/full_segformer_mit-b5_640x640_160k_td_nbg_960/'
0 commit comments