Skip to content

Commit 6da72c0

Browse files
committed
Custom Trainer with COCO eval for val
1 parent fee4005 commit 6da72c0

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

src/data_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from config import ProjConfig
77

88

9-
def register_isaid_truck_data(extra_meta={}, register_val=True, register_test):
9+
def register_isaid_truck_data(extra_meta={}, register_val=True, register_test=False):
1010
"""
1111
register project data with name isaid_truck_train/val
1212
"""

src/model_train.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,35 @@
11
# detectron
2+
import torch
23
import detectron2
34
from detectron2.utils.logger import setup_logger
45
setup_logger()
5-
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
6+
from detectron2.engine import DefaultTrainer
67
from detectron2.config import get_cfg
8+
from detectron2 import model_zoo
9+
from detectron2.evaluation import COCOEvaluator
710

811
# common imports
912
import os
13+
from datetime import datetime
1014

1115
# custom utilities
1216
from config import ProjConfig
1317
from data_utils import register_isaid_truck_data
1418

1519

16-
def setup_train_config():
20+
def setup_train_config(train_data_name, val_data_name=None, output_dir=None):
1721
"""
1822
Specify the model training configuration
1923
"""
2024
cfg = get_cfg()
2125
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
22-
cfg.DATASETS.TRAIN = ("isaid_truck_train",)
23-
cfg.DATASETS.TEST = ("isaid_truck_val", )
24-
cfg.TEST.EVAL_PERIOD = 1 # how often to eval val
26+
cfg.DATASETS.TRAIN = (train_data_name,)
27+
if val_data_name:
28+
cfg.DATASETS.TEST = (val_data_name, )
29+
cfg.TEST.EVAL_PERIOD = 1 # how often to eval val
2530
cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
2631
cfg.DATALOADER.NUM_WORKERS = 1
27-
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
32+
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
2833
cfg.MODEL.BACKBONE.FREEZE_AT = 2 # freeze the first X stages of backbone
2934
cfg.SOLVER.IMS_PER_BATCH = 2
3035
cfg.SOLVER.BASE_LR = 0.00025
@@ -34,18 +39,40 @@ def setup_train_config():
3439
cfg.SOLVER.CHECKPOINT_PERIOD = 5 # Save a checkpoint after every this number of iterations
3540
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 # default 512, smaller numbers are faster
3641
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1 # only has one class
42+
if output_dir:
43+
# specify an output with a few key hyper params
44+
cfg.OUTPUT_DIR = os.path.join(output_dir, \
45+
f'detectron_{datetime.now().strftime("%Y%m%d%H%M%S")}_freeze{cfg.MODEL.BACKBONE.FREEZE_AT}_batchsize{cfg.SOLVER.IMS_PER_BATCH}_lr{cfg.SOLVER.BASE_LR}')
46+
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
3747
return cfg
3848

3949

50+
class TrainerWithVal(DefaultTrainer):
51+
"""
52+
Build the appropriate evaluate is needed with train with a validation set
53+
"""
54+
55+
@classmethod
56+
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
57+
"""class method for evaluating the validation set"""
58+
if output_folder is None:
59+
output_folder = os.path.join(cfg.OUTPUT_DIR,"inference")
60+
return COCOEvaluator(dataset_name, output_dir=output_folder)
61+
62+
4063
def main():
4164
# configure the data
4265
proj_config = ProjConfig()
4366
_ = register_isaid_truck_data(extra_meta={}, register_val=True)
4467

45-
cfg = setup_train_config()
68+
cfg = setup_train_config(
69+
proj_config.train_data_name,
70+
proj_config.val_data_name,
71+
proj_config.model_dir
72+
)
4673
# set up the trainer: wrapper for model training with config
47-
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
48-
trainer = DefaultTrainer(cfg)
74+
75+
trainer = TrainerWithVal(cfg)
4976
trainer.resume_or_load(resume=False)
5077
trainer.train()
5178

@@ -56,4 +83,4 @@ def main():
5683

5784
# # Look at training curves in tensorboard:
5885
# %load_ext tensorboard
59-
# %tensorboard --logdir output
86+
# %tensorboard --logdir models

0 commit comments

Comments
 (0)