
Commit 6a49c73

update train test for doc_seg

1 parent e3a0b75 commit 6a49c73

File tree

12 files changed: +178 −54 lines

configs/_base_/datasets/doc_seg.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# dataset settings
+dataset_type = 'doc_segDataset'
+data_root = '/data_backup/cuongnd/mmseg/doc_seg'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (640, 640)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', reduce_zero_label=False),
+    dict(type='Resize', img_scale=(1270, 900), ratio_range=(0.9, 1.1)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1270, 900),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='imgs/train',
+        ann_dir='anno/train',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='imgs/val',
+        ann_dir='anno/val',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='imgs/val',
+        ann_dir='anno/val',
+        pipeline=test_pipeline))
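
A minimal usage sketch (not part of the commit, assuming an MMSegmentation 0.x checkout with this file in place and the hard-coded data_root available on disk): load the base config and build the training split, which applies train_pipeline to each sample.

    from mmcv import Config
    from mmseg.datasets import build_dataset

    cfg = Config.fromfile('configs/_base_/datasets/doc_seg.py')
    train_set = build_dataset(cfg.data.train)       # resolves 'doc_segDataset' via the registry
    print(len(train_set), train_set.CLASSES)        # e.g. N ('background', 'doc')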

configs/_base_/default_runtime.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # yapf:disable
 log_config = dict(
-    interval=200,
+    interval=100,
     hooks=[
         dict(type='TextLoggerHook', by_epoch=False),
         # dict(type='TensorboardLoggerHook')

configs/_base_/schedules/schedule_80k_new.py

Lines changed: 1 addition & 1 deletion
@@ -6,4 +6,4 @@
 # runtime settings
 total_iters = 80000
 checkpoint_config = dict(by_epoch=False, interval=8000)
-evaluation = dict(interval=80000, metric='mIoU')
+evaluation = dict(interval=800, metric='mIoU')
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+_base_ = [
+    '../_base_/models/fast_scnn.py', '../_base_/datasets/doc_seg.py',
+    '../_base_/default_runtime.py', '../_base_/schedules/schedule_20k.py'
+]
+
+# Re-config the data sampler.
+data = dict(samples_per_gpu=4, workers_per_gpu=4)
+
+# Re-config the optimizer.
+optimizer = dict(type='SGD', lr=0.12, momentum=0.9, weight_decay=4e-5)
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+_base_ = [
+    '../_base_/models/fast_scnn.py', '../_base_/datasets/doc_seg.py',
+    '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k_new.py'
+]
+
+# Re-config the data sampler.
+data = dict(samples_per_gpu=8, workers_per_gpu=4)
+
+# Re-config the optimizer.
+optimizer = dict(type='SGD', lr=0.12, momentum=0.9, weight_decay=4e-5)
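
Both new configs rely on _base_ inheritance: the ten lines expand into the full Fast-SCNN model plus the doc_seg dataset, runtime, and schedule settings, with the local data and optimizer dicts overriding the inherited values. A sketch of how the merge resolves (the new files' paths are not shown in this diff, so the name below is a placeholder); training would then be launched the usual MMSegmentation way, e.g. python tools/train.py <config>.

    from mmcv import Config

    # placeholder path -- the commit does not show the actual file name
    cfg = Config.fromfile('configs/fast_scnn/fast_scnn_doc_seg_80k.py')
    print(cfg.data.samples_per_gpu)   # 8, overriding the 4 set in doc_seg.py
    print(cfg.optimizer)              # SGD, lr=0.12, momentum=0.9, weight_decay=4e-5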

mmseg/apis/train.py

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ def train_segmentor(model,
         eval_cfg = cfg.get('evaluation', {})
         eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
         eval_hook = DistEvalHook if distributed else EvalHook
-        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
+        runner.register_hook(eval_hook(val_dataloader, **eval_cfg), priority='LOW')
 
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
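
For context on the priority argument, a small sketch (assumes mmcv is installed): register_hook() defaults to NORMAL priority, and 'LOW' maps to a larger number, so the evaluation hook now fires after default-priority hooks in the runner's hook chain.

    from mmcv.runner import get_priority

    print(get_priority('NORMAL'))  # 50, the register_hook default
    print(get_priority('LOW'))     # 70, what this commit passes for the eval hook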

mmseg/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -14,11 +14,12 @@
 from .publaynet_split1 import publaynet_split1Dataset
 from .doc_structure1 import doc_structure1Dataset
 from .popular_doc import popular_docDataset
+from .doc_seg import doc_segDataset
 
 __all__ = [
     'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
     'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
     'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
     'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
-    'STAREDataset','popular_docDataset'
+    'STAREDataset','popular_docDataset','doc_segDataset'
 ]

mmseg/datasets/doc_seg.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class doc_segDataset(CustomDataset):
+    """doc_segDataset
+    """
+    CLASSES = ('background', 'doc')
+    PALETTE = [[120, 120, 120], [255, 0, 0]]
+    def __init__(self, **kwargs):
+        super(doc_segDataset, self).__init__(
+            img_suffix='.jpg',
+            seg_map_suffix='.png',
+            reduce_zero_label=False,
+            **kwargs)
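
A quick check sketch (assumes mmseg is importable with the updated mmseg/datasets/__init__.py from this commit): registration via @DATASETS.register_module() is what lets the string 'doc_segDataset' in configs/_base_/datasets/doc_seg.py resolve to this class.

    from mmseg.datasets import DATASETS

    print(DATASETS.get('doc_segDataset'))  # the registered class, or None if registration did not run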

tools/labelme2normal.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+import os, cv2
+import numpy as np
+import json
+
+def get_list_file_in_folder(dir, ext=['jpg', 'png', 'JPG', 'PNG']):
+    included_extensions = ext
+    file_names = [fn for fn in os.listdir(dir)
+                  if any(fn.endswith(ext) for ext in included_extensions)]
+    return file_names
+
+def convert_labelme_label_to_normal_format(src_anno_dir, src_img_dir, dst_anno_dir, label_list, debug=False):
+    print('convert_labelme_label_to_normal_format')
+    print('src_anno_dir', src_anno_dir)
+    print('dst_anno_dir', dst_anno_dir)
+
+    list_imgs = get_list_file_in_folder(src_img_dir)
+    list_imgs = sorted(list_imgs)
+
+    count_samples = {}
+    for label in label_list:
+        count_samples[label] = 0
+
+    for idx, img_name in enumerate(list_imgs):
+        base_name = img_name.split('.')[0]
+        if idx < 0:
+            continue
+        print(idx, 'labelme2normal. Convert', base_name)
+
+        json_path = os.path.join(src_anno_dir, base_name + '.json')
+        img = cv2.imread(os.path.join(src_img_dir, img_name))
+
+        segment_img = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
+        with open(json_path) as json_file:
+            data = json.load(json_file)
+            shapes = data["shapes"]
+            for shape in shapes:
+                point = shape["points"]
+                label = shape["label"]
+                pts = np.asarray(point, np.int32)
+                # label_idx = label_list.index(label)
+                label_idx = 1  # only the document region, with idx = 1
+                if label not in count_samples.keys():
+                    count_samples[label] = 0
+                count_samples[label] += 1
+
+                # color = int(20*label_idx)
+                color = label_idx
+
+                cv2.fillPoly(segment_img, pts=[pts], color=color)
+                if debug:
+                    cv2.imshow('origin', img)
+                    cv2.imshow('mask', segment_img)
+                    cv2.waitKey(0)
+
+        output_anno_path = os.path.join(dst_anno_dir, base_name + '.png')
+        cv2.imwrite(output_anno_path, segment_img)
+
+    print('Number of samples', count_samples)
+
+
+if __name__ == "__main__":
+    # test = cv2.imread('/home/duycuong/PycharmProjects/ocr/others/conversion_tools/segmentation/00004.png')
+
+    src_anno_dir = '/data_backup/cuongnd/mmseg/doc_seg/anno/bhyt'
+    src_img_dir = '/data_backup/cuongnd/Viettel_freeform/MAFC/BHYT_origin/imgs/clean'
+    dst_anno_dir = '/data_backup/cuongnd/mmseg/doc_seg/anno/bhyt_imgs'
+    label_list = ['background', 'cccd', 'cccd_back', 'cmnd_new', 'cmnd_old', 'cmnd_old_back',
+                  'driverlicense_new', 'driverlicense_new_back', 'driverlicense_old', 'driverlicense_old_back']
+    convert_labelme_label_to_normal_format(src_anno_dir,
+                                           src_img_dir,
+                                           dst_anno_dir,
+                                           label_list,
+                                           debug=False)
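
A quick sanity-check sketch for the converted masks (the file name below is a placeholder): because every polygon is filled with label_idx = 1, each output PNG should contain only the values 0 (background) and 1 (document region), matching the CLASSES order in mmseg/datasets/doc_seg.py.

    import cv2
    import numpy as np

    # placeholder mask path under the dst_anno_dir used above
    mask = cv2.imread('/data_backup/cuongnd/mmseg/doc_seg/anno/bhyt_imgs/example.png', cv2.IMREAD_GRAYSCALE)
    print(np.unique(mask))  # expected: [0 1] (just [0] if the image had no labelled shape)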

tools/prepare_segmentation_data.py

Lines changed: 7 additions & 48 deletions
@@ -57,26 +57,6 @@ def convert_anno_detection_to_segmentation(img_dir, anno_det_dir, output_anno_segment_dir, ...
             cv2.rectangle(anno_mask,(int(left)-extend,int(top)-extend),(int(right)+extend,int(bottom)+extend),1,-1)
         cv2.imwrite(os.path.join(output_anno_segment_dir,img_name),anno_mask)
 
-def convert_anno_objective2_to_segmentation(img_dir, anno_det_dir, output_anno_segment_dir, extend=-1, format_anno_det='icdar', class_list=dict()):
-    list_images = get_list_file_in_folder(img_dir)
-    list_images = sorted(list_images)
-    for idx, img_name in enumerate(list_images):
-        print(idx, img_name)
-        img_path=os.path.join(img_dir,img_name)
-        img = cv2.imread(img_path)
-        anno_mask = np.zeros((img.shape[0], img.shape[1]), np.uint8)
-        anno_file = os.path.join(anno_det_dir,img_name.replace('.jpg','.json').replace('.png','.json'))
-
-        import json
-        with open(anno_file, "r") as anno:
-            anno_str = json.load(anno)
-
-        for i, line in enumerate(anno_str['cellboxes']):
-            left, top, right, bottom = line[0], line[1], line[2], line[3]
-            cv2.rectangle(anno_mask,(int(left)-extend,int(top)-extend),(int(right)+extend,int(bottom)+extend),1,-1)
-        cv2.imwrite(os.path.join(output_anno_segment_dir,img_name),anno_mask)
-    print('ok')
-
 def split_dataset(img_dir, ann_dir, img_dst_dir, ann_dst_dir, ratio=0.5):
     list_images = get_list_file_in_folder(img_dir)
     random.shuffle(list_images)
@@ -260,41 +240,20 @@ def visualize_normal_format_dataset(img_dir, ann_dir):
 
 
 if __name__=='__main__':
-    #img=cv2.imread('/home/cuongnd/PycharmProjects/aicr/source/mmsegmentation/data/ade/ADEChallengeData2016/annotations/validation/ADE_val_00000012.png', cv2.IMREAD_GRAYSCALE)
-
-    data_dir='/data20.04/data/table recognition/from_Korea/201012_172754_pubtabnet_valid_sample_objective#2'
-    img_dir= data_dir + '/images'
-    anno_det_dir=data_dir + '/annots'
-    output_anno_segment_dir=data_dir + '/annot_seg'
-
-
-    #convert_anno_objective2_to_segmentation(img_dir, anno_det_dir, output_anno_segment_dir)
-    #
     # split_dataset(img_dir='/data4T/cuongnd/dataset/publaynet_split1/img_dir/train',
     #               ann_dir='/data4T/cuongnd/dataset/publaynet_split1/ann_dir/train_3classes',
     #               img_dst_dir='/data4T/cuongnd/dataset/doc_structure1/img_dir/train',
     #               ann_dst_dir='/data4T/cuongnd/dataset/doc_structure1/ann_dir/train',
     #               ratio=0.002)
 
-    # del_dataset(img_dir='/data20.04/data/doc_structure/publaynet/img_dir/train',
-    #             ann_dir='/data20.04/data/doc_structure/publaynet/ann_dir/train')
-
-    src_anno_dir='/data4T/cuongnd/dataset/publaynet_split1/ann_dir/val'
-    dst_anno_dir='/data4T/cuongnd/dataset/publaynet_split1/ann_dir/val_3classes'
-    # refactor_classes_of_dataset(src_anno_dir, dst_anno_dir,
-    #                             src_classes=[1, 2, 3, 4, 5], #('text', 'title', 'list', 'table', 'figure')
-    #                             dst_classes=[1, 1, 3, 2, 1])
-
+    src_anno_dir='/data_backup/cuongnd/Viettel_freeform/MAFC/BHYT_origin/imgs/clean'
+    dst_anno_dir='/data_backup/cuongnd/mmseg/doc_seg/imgs/bhyt'
+    convert_all_imgs_to_jpg(src_anno_dir,dst_anno_dir)
 
-    #onvert_voc_label_to_normal_format(src_anno_dir,dst_anno_dir)
 
-    #convert_all_imgs_to_jpg(src_anno_dir,dst_anno_dir)
-    #
-    # refine_dataset(img_dir='/data4T/ntanh/publaynet/train',
-    #                ann_dir='/data4T/ntanh/publaynet_gen_gt_oct2.1/train/label')
-    img_dir='/home/duycuong/home_data/mmlab/mmseg/popular_doc/images/train'
-    ann_dir='/home/duycuong/home_data/mmlab/mmseg/popular_doc/annotations/train'
-    visualize_normal_format_dataset(img_dir=img_dir,
-                                    ann_dir=ann_dir)
+    # img_dir='/data_backup/cuongnd/mmseg/doc_seg_data/imgs/train'
+    # ann_dir='/data_backup/cuongnd/mmseg/doc_seg_data/anno/train'
+    # visualize_normal_format_dataset(img_dir=img_dir,
+    #                                 ann_dir=ann_dir)
 
 