1
1
# detectron
2
+ import torch
2
3
import detectron2
3
4
from detectron2 .utils .logger import setup_logger
4
5
setup_logger ()
5
- from detectron2 .engine import DefaultTrainer , default_argument_parser , default_setup , launch
6
+ from detectron2 .engine import DefaultTrainer
6
7
from detectron2 .config import get_cfg
8
+ from detectron2 import model_zoo
9
+ from detectron2 .evaluation import COCOEvaluator
7
10
8
11
# common imports
9
12
import os
13
+ from datetime import datetime
10
14
11
15
# custom utilities
12
16
from config import ProjConfig
13
17
from data_utils import register_isaid_truck_data
14
18
15
19
16
- def setup_train_config ():
20
+ def setup_train_config (train_data_name , val_data_name = None , output_dir = None ):
17
21
"""
18
22
Specify the model training configuration
19
23
"""
20
24
cfg = get_cfg ()
21
25
cfg .merge_from_file (model_zoo .get_config_file ("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml" ))
22
- cfg .DATASETS .TRAIN = ("isaid_truck_train" ,)
23
- cfg .DATASETS .TEST = ("isaid_truck_val" , )
24
- cfg .TEST .EVAL_PERIOD = 1 # how often to eval val
26
+ cfg .DATASETS .TRAIN = (train_data_name ,)
27
+ if val_data_name :
28
+ cfg .DATASETS .TEST = (val_data_name , )
29
+ cfg .TEST .EVAL_PERIOD = 1 # how often to eval val
25
30
cfg .MODEL .DEVICE = 'cuda' if torch .cuda .is_available () else 'cpu'
26
31
cfg .DATALOADER .NUM_WORKERS = 1
27
- cfg .MODEL .WEIGHTS = model_zoo .get_checkpoint_url ("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml" )
32
+ cfg .MODEL .WEIGHTS = model_zoo .get_checkpoint_url ("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml" )
28
33
cfg .MODEL .BACKBONE .FREEZE_AT = 2 # freeze the first X stages of backbone
29
34
cfg .SOLVER .IMS_PER_BATCH = 2
30
35
cfg .SOLVER .BASE_LR = 0.00025
@@ -34,18 +39,40 @@ def setup_train_config():
34
39
cfg .SOLVER .CHECKPOINT_PERIOD = 5 # Save a checkpoint after every this number of iterations
35
40
cfg .MODEL .ROI_HEADS .BATCH_SIZE_PER_IMAGE = 128 # default 512, smaller numbers are faster
36
41
cfg .MODEL .ROI_HEADS .NUM_CLASSES = 1 # only has one class
42
+ if output_dir :
43
+ # specify an output with a few key hyper params
44
+ cfg .OUTPUT_DIR = os .path .join (output_dir , \
45
+ f'detectron_{ datetime .now ().strftime ("%Y%m%d%H%M%S" )} _freeze{ cfg .MODEL .BACKBONE .FREEZE_AT } _batchsize{ cfg .SOLVER .IMS_PER_BATCH } _lr{ cfg .SOLVER .BASE_LR } ' )
46
+ os .makedirs (cfg .OUTPUT_DIR , exist_ok = True )
37
47
return cfg
38
48
39
49
50
+ class TrainerWithVal (DefaultTrainer ):
51
+ """
52
+ Build the appropriate evaluate is needed with train with a validation set
53
+ """
54
+
55
+ @classmethod
56
+ def build_evaluator (cls , cfg , dataset_name , output_folder = None ):
57
+ """class method for evaluating the validation set"""
58
+ if output_folder is None :
59
+ output_folder = os .path .join (cfg .OUTPUT_DIR ,"inference" )
60
+ return COCOEvaluator (dataset_name , output_dir = output_folder )
61
+
62
+
40
63
def main ():
41
64
# configure the data
42
65
proj_config = ProjConfig ()
43
66
_ = register_isaid_truck_data (extra_meta = {}, register_val = True )
44
67
45
- cfg = setup_train_config ()
68
+ cfg = setup_train_config (
69
+ proj_config .train_data_name ,
70
+ proj_config .val_data_name ,
71
+ proj_config .model_dir
72
+ )
46
73
# set up the trainer: wrapper for model training with config
47
- os . makedirs ( cfg . OUTPUT_DIR , exist_ok = True )
48
- trainer = DefaultTrainer (cfg )
74
+
75
+ trainer = TrainerWithVal (cfg )
49
76
trainer .resume_or_load (resume = False )
50
77
trainer .train ()
51
78
@@ -56,4 +83,4 @@ def main():
56
83
57
84
# # Look at training curves in tensorboard:
58
85
# %load_ext tensorboard
59
- # %tensorboard --logdir output
86
+ # %tensorboard --logdir models
0 commit comments